1/**
2 * This file is adapted from PyTorch/XLA
3 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/metrics.h
4 */
5
6#pragma once
7
8#include <atomic>
9#include <functional>
10#include <map>
11#include <memory>
12#include <mutex>
13#include <string>
14#include <vector>
15
16#include <c10/macros/Export.h>
17
18namespace torch {
19namespace lazy {
20
21struct TORCH_API Sample {
22 Sample() = default;
23 Sample(int64_t timestamp_ns, double value)
24 : timestamp_ns(timestamp_ns), value(value) {}
25
26 int64_t timestamp_ns = 0;
27 double value = 0;
28};
29
// Signature of the formatter used to turn a sample value into display text.
// MetricData::Repr() forwards to a function of this type; MetricFnValue,
// MetricFnBytes and MetricFnTime are the formatters declared in this header.
using MetricReprFn = std::function<std::string(double)>;
31
32// Class used to collect time-stamped numeric samples. The samples are stored in
33// a circular buffer whose size can be configured at constructor time.
34class TORCH_API MetricData {
35 public:
36 // Creates a new MetricData object with the internal circular buffer storing
37 // max_samples samples. The repr_fn argument allow to specify a function which
38 // pretty-prints a sample value.
39 MetricData(MetricReprFn repr_fn, size_t max_samples);
40
41 // Returns the total values of all the samples being posted to this metric.
42 double Accumulator() const;
43
44 size_t TotalSamples() const;
45
46 void AddSample(int64_t timestamp_ns, double value);
47
48 // Returns a vector with all the current samples, from the oldest to the
49 // newer. If accumulator is not nullptr, it will receive the current value of
50 // the metrics' accumulator (the sum of all posted values). If total_samples
51 // is not nullptr, it will receive the count of the posted values.
52 std::vector<Sample> Samples(double* accumulator, size_t* total_samples) const;
53
54 std::string Repr(double value) const {
55 return repr_fn_(value);
56 }
57
58 void Reset();
59
60 bool IsValid() const {
61 return TotalSamples() > 0;
62 }
63
64 private:
65 mutable std::mutex lock_;
66 MetricReprFn repr_fn_;
67 size_t count_ = 0;
68 std::vector<Sample> samples_;
69 double accumulator_ = 0.0;
70};
71
72// Counters are a very lightweight form of metrics which do not need to track
73// sample time.
74class TORCH_API CounterData {
75 public:
76 CounterData() : value_(0) {}
77
78 void AddValue(int64_t value) {
79 value_ += value;
80 }
81
82 int64_t Value() const {
83 return value_;
84 }
85
86 void Reset() {
87 value_ = 0;
88 }
89
90 bool IsValid() const {
91 return value_ > 0;
92 }
93
94 private:
95 std::atomic<int64_t> value_;
96};
97
98class TORCH_API MetricsArena {
99 public:
100 static MetricsArena* Get();
101
102 void ResetCounters();
103 void ResetMetrics();
104
105 // Registers a new metric in the global arena.
106 void RegisterMetric(
107 const std::string& name,
108 MetricReprFn repr_fn,
109 size_t max_samples,
110 std::shared_ptr<MetricData>* data);
111
112 void RegisterCounter(
113 const std::string& name,
114 std::shared_ptr<CounterData>* data);
115
116 void ForEachMetric(
117 const std::function<void(const std::string&, MetricData*)>& metric_func);
118
119 void ForEachCounter(
120 const std::function<void(const std::string&, CounterData*)>&
121 counter_func);
122
123 std::vector<std::string> GetMetricNames();
124
125 MetricData* GetMetric(const std::string& name);
126
127 std::vector<std::string> GetCounterNames();
128
129 CounterData* GetCounter(const std::string& name);
130
131 private:
132 std::mutex lock_;
133 std::map<std::string, std::shared_ptr<MetricData>> metrics_;
134 std::map<std::string, std::shared_ptr<CounterData>> counters_;
135};
136
// Ready-made MetricReprFn formatters.

// Emits the value in a to_string() conversion.
TORCH_API std::string MetricFnValue(double value);
// Emits the value in a humanized bytes representation.
TORCH_API std::string MetricFnBytes(double value);
// Emits the value in a humanized time representation. The value is expressed in
// nanoseconds.
// NOTE(review): this comment used to say "nanoseconds EPOCH time", but
// TORCH_LAZY_TIMED pairs this formatter with TimedSection, which posts elapsed
// nanoseconds (end - start) as the value; it therefore looks like a duration
// rather than an absolute timestamp -- confirm against the implementation.
TORCH_API std::string MetricFnTime(double value);
144
145// The typical use of a Metric is one in which it gets created either in a
146// global scope context:
147// static Metric* metric = new Metric("RpcCount");
148// Or within a function scope:
149// void MyFunction(...) {
150// static Metric* metric = new Metric("RpcCount");
151// ...
152// metric->AddSample(ts_nanos, some_value);
153// }
154class TORCH_API Metric {
155 public:
156 explicit Metric(
157 std::string name,
158 MetricReprFn repr_fn = MetricFnValue,
159 size_t max_samples = 0);
160
161 const std::string& Name() const {
162 return name_;
163 }
164
165 double Accumulator() const;
166
167 void AddSample(int64_t timestamp_ns, double value);
168
169 void AddSample(double value);
170
171 std::vector<Sample> Samples(double* accumulator, size_t* total_samples) const;
172
173 std::string Repr(double value) const;
174
175 private:
176 MetricData* GetData() const;
177
178 std::string name_;
179 MetricReprFn repr_fn_;
180 size_t max_samples_;
181 mutable std::shared_ptr<MetricData> data_ptr_;
182 mutable std::atomic<MetricData*> data_;
183};
184
185// A Counter is a lightweight form of metric which tracks an integer value which
186// can increase or decrease.
187// A typical use is as:
188// static Counter* counter = new Counter("MyCounter");
189// ...
190// counter->AddValue(+1);
191class TORCH_API Counter {
192 public:
193 explicit Counter(std::string name);
194
195 void AddValue(int64_t value) {
196 GetData()->AddValue(value);
197 }
198
199 int64_t Value() const {
200 return GetData()->Value();
201 }
202
203 private:
204 CounterData* GetData() const;
205
206 std::string name_;
207 mutable std::shared_ptr<CounterData> data_ptr_;
208 mutable std::atomic<CounterData*> data_;
209};
210
// Adds `value` to the counter called `name`.
//
// Each expansion owns a function-local static Counter that is allocated on the
// first execution and never deleted. Because the pointer is static, `name` is
// evaluated only the first time a given expansion runs; `value` is evaluated on
// every execution. The local name deliberately avoids a double underscore,
// since such identifiers are reserved for the implementation.
#define TORCH_LAZY_COUNTER(name, value)                       \
  do {                                                        \
    static ::torch::lazy::Counter* torch_lazy_counter_ptr =   \
        new ::torch::lazy::Counter(name);                     \
    torch_lazy_counter_ptr->AddValue(value);                  \
  } while (0)
217
// Increments by one a counter whose name is `ns` followed by the name of the
// enclosing function (__func__). c10::str is globally qualified so that a
// nested `c10` namespace at the expansion site cannot capture the call.
// NOTE(review): a declaration of c10::str is not among this header's direct
// includes, so it has to be visible wherever the macro is expanded.
#define TORCH_LAZY_FN_COUNTER(ns) \
  TORCH_LAZY_COUNTER(::c10::str(ns, __func__), 1)
219
// Posts `value` as a sample of the metric called `name`, formatted with
// MetricFnValue.
//
// Each expansion owns a function-local static Metric that is allocated on the
// first execution and never deleted. Because the pointer is static, `name` is
// evaluated only the first time a given expansion runs; `value` is evaluated on
// every execution. All names are globally qualified and the local name avoids a
// double underscore, since such identifiers are reserved for the
// implementation.
#define TORCH_LAZY_VALUE_METRIC(name, value)                            \
  do {                                                                  \
    static ::torch::lazy::Metric* torch_lazy_metric_ptr =               \
        new ::torch::lazy::Metric(name, ::torch::lazy::MetricFnValue);  \
    torch_lazy_metric_ptr->AddSample(value);                            \
  } while (0)
226
// Creates a textual report with the current metrics statistics.
TORCH_API std::string CreateMetricReport();

// Creates a textual report with the selected metrics statistics; the selection
// is given by name through counter_names and metric_names.
TORCH_API std::string CreateMetricReport(
    const std::vector<std::string>& counter_names,
    const std::vector<std::string>& metric_names);
234
// Returns the currently registered metric names. Note that the list can grow
// since metrics are usually function initialized (they are static function
// variables).
TORCH_API std::vector<std::string> GetMetricNames();

// Retrieves the metric data of a given metric, or nullptr if such metric does
// not exist.
TORCH_API MetricData* GetMetric(const std::string& name);

// Returns the currently registered counter names. Note that the list can grow
// since counters are usually function initialized (they are static function
// variables).
TORCH_API std::vector<std::string> GetCounterNames();

// Retrieves the counter data of a given counter, or nullptr if such counter
// does not exist.
TORCH_API CounterData* GetCounter(const std::string& name);
252
// Retrieves the current EPOCH time in nanoseconds. TimedSection calls this to
// take its start and end readings.
TORCH_API int64_t NowNs();
255
256// Scope based utility class TORCH_API to measure the time the code takes within
257// a given C++ scope.
258class TORCH_API TimedSection {
259 public:
260 explicit TimedSection(Metric* metric) : metric_(metric), start_(NowNs()) {}
261
262 ~TimedSection() {
263 int64_t now = NowNs();
264 metric_->AddSample(now, now - start_);
265 }
266
267 double Elapsed() const {
268 return 1e-9 * static_cast<double>(NowNs() - start_);
269 }
270
271 private:
272 Metric* metric_;
273 int64_t start_;
274};
275
// Times the rest of the enclosing scope and posts the elapsed time to the
// metric called `name`, formatted with MetricFnTime.
//
// Unlike the counter/value macros this expands to two declarations rather than
// a do/while block, because `timed_section` has to live until the end of the
// caller's scope. As a consequence it introduces the names `timed_metric` and
// `timed_section` into that scope, can be used only once per scope, and must
// not be the sole statement of an unbraced if/for/while. `name` is evaluated
// only the first time the expansion runs, since `timed_metric` is static and
// never deleted. All names are globally qualified so that a nested `torch`
// namespace at the expansion site cannot capture them.
#define TORCH_LAZY_TIMED(name)                                      \
  static ::torch::lazy::Metric* timed_metric =                      \
      new ::torch::lazy::Metric(name, ::torch::lazy::MetricFnTime); \
  ::torch::lazy::TimedSection timed_section(timed_metric)
280
281} // namespace lazy
282} // namespace torch
283