1#include <torch/csrc/lazy/core/metrics.h>
2
3#include <c10/util/Logging.h>
4#include <c10/util/irange.h>
5#include <torch/csrc/lazy/backend/backend_interface.h>
6#include <torch/csrc/lazy/core/config.h>
7#include <torch/csrc/lazy/core/helpers.h>
8#include <torch/csrc/lazy/core/util.h>
9
10#include <algorithm>
11#include <chrono>
12#include <cmath>
13#include <sstream>
14
15namespace torch {
16namespace lazy {
17namespace {
18
19const std::vector<double>* ReadEnvPercentiles() {
20 std::vector<std::string> percentiles_list =
21 StrSplit(FLAGS_torch_lazy_metrics_percentiles, ':');
22 std::unique_ptr<std::vector<double>> metrics_percentiles =
23 std::make_unique<std::vector<double>>();
24 for (auto& pct_str : percentiles_list) {
25 double pct = std::stod(pct_str);
26 TORCH_CHECK(pct > 0.0 && pct < 1.0, "Invalid percentile: ", pct);
27 metrics_percentiles->push_back(pct);
28 }
29 std::sort(metrics_percentiles->begin(), metrics_percentiles->end());
30 return metrics_percentiles.release();
31}
32
33const std::vector<double>& GetPercentiles() {
34 static const std::vector<double>* metrics_percentiles = ReadEnvPercentiles();
35 return *metrics_percentiles;
36}
37
38void EmitMetricInfo(
39 const std::string& name,
40 MetricData* data,
41 std::stringstream* ss) {
42 double accumulator = 0.0;
43 size_t total_samples = 0;
44 std::vector<Sample> samples = data->Samples(&accumulator, &total_samples);
45 (*ss) << "Metric: " << name << std::endl;
46 (*ss) << " TotalSamples: " << total_samples << std::endl;
47 (*ss) << " Accumulator: " << data->Repr(accumulator) << std::endl;
48 if (!samples.empty()) {
49 double total = 0.0;
50 for (auto& sample : samples) {
51 total += sample.value;
52 }
53 int64_t delta_time =
54 samples.back().timestamp_ns - samples.front().timestamp_ns;
55 if (delta_time > 0) {
56 double value_sec = 1e6 * (total / (delta_time / 1000.0));
57 (*ss) << " ValueRate: " << data->Repr(value_sec) << " / second"
58 << std::endl;
59 double count_sec =
60 1e6 * (static_cast<double>(samples.size()) / (delta_time / 1000.0));
61 (*ss) << " Rate: " << count_sec << " / second" << std::endl;
62 }
63 }
64
65 const std::vector<double>& metrics_percentiles = GetPercentiles();
66 std::sort(
67 samples.begin(), samples.end(), [](const Sample& s1, const Sample& s2) {
68 return s1.value < s2.value;
69 });
70 (*ss) << " Percentiles: ";
71 for (const auto i : c10::irange(metrics_percentiles.size())) {
72 size_t index = metrics_percentiles[i] * samples.size();
73 if (i > 0) {
74 (*ss) << "; ";
75 }
76 (*ss) << (metrics_percentiles[i] * 100.0)
77 << "%=" << data->Repr(samples[index].value);
78 }
79 (*ss) << std::endl;
80}
81
82void EmitCounterInfo(
83 const std::string& name,
84 CounterData* data,
85 std::stringstream* ss) {
86 (*ss) << "Counter: " << name << std::endl;
87 (*ss) << " Value: " << data->Value() << std::endl;
88}
89
// Looks up `key` in the associative container `cont`. When the key is absent,
// `gen()` is invoked once and its result is stored under `key`. In both cases
// a reference to the mapped value held by the container is returned.
// `gen` is not called when the key already exists.
template <typename T, typename G>
const typename T::mapped_type& MapInsert(
    T* cont,
    const typename T::key_type& key,
    const G& gen) {
  const auto existing = cont->find(key);
  if (existing != cont->end()) {
    return existing->second;
  }
  return cont->emplace(key, gen()).first->second;
}
101
102} // namespace
103
104MetricsArena* MetricsArena::Get() {
105 static MetricsArena* arena = new MetricsArena();
106 return arena;
107}
108
109void MetricsArena::ResetCounters() {
110 for (auto& pair : counters_) {
111 if (pair.second) {
112 pair.second->Reset();
113 }
114 }
115}
116
117void MetricsArena::ResetMetrics() {
118 for (auto& pair : metrics_) {
119 if (pair.second) {
120 pair.second->Reset();
121 }
122 }
123}
124
125void MetricsArena::RegisterMetric(
126 const std::string& name,
127 MetricReprFn repr_fn,
128 size_t max_samples,
129 std::shared_ptr<MetricData>* data) {
130 std::lock_guard<std::mutex> lock(lock_);
131 if (*data == nullptr) {
132 *data = MapInsert(&metrics_, name, [&]() {
133 return std::make_shared<MetricData>(std::move(repr_fn), max_samples);
134 });
135 }
136}
137
138void MetricsArena::RegisterCounter(
139 const std::string& name,
140 std::shared_ptr<CounterData>* data) {
141 std::lock_guard<std::mutex> lock(lock_);
142 if (*data == nullptr) {
143 *data = MapInsert(
144 &counters_, name, []() { return std::make_shared<CounterData>(); });
145 }
146}
147
148void MetricsArena::ForEachMetric(
149 const std::function<void(const std::string&, MetricData*)>& metric_func) {
150 std::lock_guard<std::mutex> lock(lock_);
151 for (auto& name_data : metrics_) {
152 if (!name_data.second->IsValid()) {
153 continue;
154 }
155 metric_func(name_data.first, name_data.second.get());
156 }
157}
158
159void MetricsArena::ForEachCounter(
160 const std::function<void(const std::string&, CounterData*)>& counter_func) {
161 std::lock_guard<std::mutex> lock(lock_);
162 for (auto& name_data : counters_) {
163 if (!name_data.second->IsValid())
164 continue;
165 counter_func(name_data.first, name_data.second.get());
166 }
167}
168
169std::vector<std::string> MetricsArena::GetMetricNames() {
170 std::vector<std::string> names;
171 ForEachMetric([&names](const std::string& name, MetricData* data) {
172 names.push_back(name);
173 });
174 return names;
175}
176
177MetricData* MetricsArena::GetMetric(const std::string& name) {
178 std::lock_guard<std::mutex> lock(lock_);
179 auto it = metrics_.find(name);
180 if (it == metrics_.end()) {
181 return nullptr;
182 }
183 return it->second->IsValid() ? it->second.get() : nullptr;
184}
185
186std::vector<std::string> MetricsArena::GetCounterNames() {
187 std::vector<std::string> names;
188 ForEachCounter([&names](const std::string& name, CounterData* data) {
189 names.push_back(name);
190 });
191 return names;
192}
193
194CounterData* MetricsArena::GetCounter(const std::string& name) {
195 std::lock_guard<std::mutex> lock(lock_);
196 auto it = counters_.find(name);
197 if (it == counters_.end()) {
198 return nullptr;
199 }
200 return it->second->IsValid() ? it->second.get() : nullptr;
201}
202
203MetricData::MetricData(MetricReprFn repr_fn, size_t max_samples)
204 : repr_fn_(std::move(repr_fn)), samples_(max_samples) {}
205
206void MetricData::AddSample(int64_t timestamp_ns, double value) {
207 std::lock_guard<std::mutex> lock(lock_);
208 size_t position = count_ % samples_.size();
209 ++count_;
210 accumulator_ += value;
211 samples_[position] = Sample(timestamp_ns, value);
212}
213
214double MetricData::Accumulator() const {
215 std::lock_guard<std::mutex> lock(lock_);
216 return accumulator_;
217}
218
219size_t MetricData::TotalSamples() const {
220 std::lock_guard<std::mutex> lock(lock_);
221 return count_;
222}
223
224std::vector<Sample> MetricData::Samples(
225 double* accumulator,
226 size_t* total_samples) const {
227 std::lock_guard<std::mutex> lock(lock_);
228 std::vector<Sample> samples;
229 if (count_ <= samples_.size()) {
230 samples.insert(samples.end(), samples_.begin(), samples_.begin() + count_);
231 } else {
232 size_t position = count_ % samples_.size();
233 samples.insert(samples.end(), samples_.begin() + position, samples_.end());
234 samples.insert(
235 samples.end(), samples_.begin(), samples_.begin() + position);
236 }
237 if (accumulator != nullptr) {
238 *accumulator = accumulator_;
239 }
240 if (total_samples != nullptr) {
241 *total_samples = count_;
242 }
243 return samples;
244}
245
246void MetricData::Reset() {
247 std::lock_guard<std::mutex> lock(lock_);
248 count_ = 0;
249 // Don't clear. samples_ are init with placeholders.
250 samples_ = std::vector<Sample>(samples_.size());
251 accumulator_ = 0.0;
252}
253
254Metric::Metric(std::string name, MetricReprFn repr_fn, size_t max_samples)
255 : name_(std::move(name)),
256 repr_fn_(std::move(repr_fn)),
257 max_samples_(
258 max_samples != 0 ? max_samples : FLAGS_torch_lazy_metrics_samples),
259 data_(nullptr) {}
260
261double Metric::Accumulator() const {
262 return GetData()->Accumulator();
263}
264
265void Metric::AddSample(int64_t timestamp_ns, double value) {
266 GetData()->AddSample(timestamp_ns, value);
267}
268
269void Metric::AddSample(double value) {
270 GetData()->AddSample(NowNs(), value);
271}
272
273std::vector<Sample> Metric::Samples(double* accumulator, size_t* total_samples)
274 const {
275 return GetData()->Samples(accumulator, total_samples);
276}
277
278std::string Metric::Repr(double value) const {
279 return GetData()->Repr(value);
280}
281
282MetricData* Metric::GetData() const {
283 MetricData* data = data_.load();
284 if (C10_UNLIKELY(data == nullptr)) {
285 // The RegisterMetric() API is a synchronization point, and even if multiple
286 // threads enters it, the data will be created only once.
287 MetricsArena* arena = MetricsArena::Get();
288 arena->RegisterMetric(name_, repr_fn_, max_samples_, &data_ptr_);
289 // Even if multiple threads will enter this IF statement, they will all
290 // fetch the same value, and hence store the same value below.
291 data = data_ptr_.get();
292 data_.store(data);
293 }
294 return data;
295}
296
297Counter::Counter(std::string name) : name_(std::move(name)), data_(nullptr) {}
298
299CounterData* Counter::GetData() const {
300 CounterData* data = data_.load();
301 if (C10_UNLIKELY(data == nullptr)) {
302 // The RegisterCounter() API is a synchronization point, and even if
303 // multiple threads enters it, the data will be created only once.
304 MetricsArena* arena = MetricsArena::Get();
305 arena->RegisterCounter(name_, &data_ptr_);
306 // Even if multiple threads will enter this IF statement, they will all
307 // fetch the same value, and hence store the same value below.
308 data = data_ptr_.get();
309 data_.store(data);
310 }
311 return data;
312}
313
// Formats `value` in fixed-point notation with two digits after the decimal
// point, e.g. 3.14159 -> "3.14".
std::string MetricFnValue(double value) {
  std::stringstream formatted;
  formatted.setf(std::ios_base::fixed, std::ios_base::floatfield);
  formatted.precision(2);
  formatted << value;
  return formatted.str();
}
320
// Formats a byte count using binary (1024-based) units, with two digits after
// the decimal point and the unit suffix appended directly, e.g.
// 1536 -> "1.50KB". Values beyond the largest unit stay in "PB".
std::string MetricFnBytes(double value) {
  static constexpr std::array<const char*, 6> kSizeSuffixes{
      "B", "KB", "MB", "GB", "TB", "PB"};
  // Unsigned index so the bound check against size() compares like types.
  size_t sfix = 0;
  for (; (sfix + 1) < kSizeSuffixes.size() && value >= 1024.0; ++sfix) {
    value /= 1024.0;
  }
  std::stringstream ss;
  ss.precision(2);
  ss << std::fixed << value << kSizeSuffixes[sfix];
  return ss.str();
}
333
// Formats a duration given in nanoseconds as a sequence of zero-padded parts
// (days, hours, minutes, seconds, milliseconds, microseconds), e.g.
// 1.5e9 -> "01s500ms000.000us".
//
// Leading parts whose value is below one are omitted; once a part has been
// written, all smaller parts follow. The microsecond part is always written
// and keeps three fractional digits.
std::string MetricFnTime(double value) {
  struct TimePart {
    const char* suffix;
    double scaler; // nanoseconds per unit
    int width;
    int precision;
    char fill;
  };
  static const std::array<TimePart, 6> time_parts{
      {{"d", 86400.0 * 1e9, 2, 0, '0'},
       {"h", 3600.0 * 1e9, 2, 0, '0'},
       {"m", 60.0 * 1e9, 2, 0, '0'},
       {"s", 1e9, 2, 0, '0'},
       {"ms", 1e6, 3, 0, '0'},
       {"us", 1e3, 7, 3, '0'}}};
  bool emitted = false;
  std::stringstream ss;
  for (const TimePart& part : time_parts) {
    const bool is_last = (&part == &time_parts.back());
    const double ctime = value / part.scaler;
    if (ctime >= 1.0 || emitted || is_last) {
      const double whole = std::floor(ctime);
      ss.precision(part.precision);
      ss.width(part.width);
      ss.fill(part.fill);
      // Non-final parts show only their whole units: the fraction is carried
      // into the smaller parts below, so letting the stream round it (as
      // precision 0 would) reports it twice, e.g. 1.5s as "02s500ms".
      ss << std::fixed << (is_last ? ctime : whole) << part.suffix;
      value -= whole * part.scaler;
      emitted = true;
    }
  }
  return ss.str();
}
365
366std::string CreateMetricReport() {
367 MetricsArena* arena = MetricsArena::Get();
368 std::stringstream ss;
369 arena->ForEachMetric([&ss](const std::string& name, MetricData* data) {
370 EmitMetricInfo(name, data, &ss);
371 });
372 arena->ForEachCounter([&ss](const std::string& name, CounterData* data) {
373 EmitCounterInfo(name, data, &ss);
374 });
375
376 // Append the backend metrics report
377 ss << getBackend()->CreateMetricReport();
378 return ss.str();
379}
380
381std::string CreateMetricReport(
382 const std::vector<std::string>& counter_names,
383 const std::vector<std::string>& metric_names) {
384 MetricsArena* arena = MetricsArena::Get();
385 std::stringstream ss;
386 std::set<std::string> metric_name_set(
387 metric_names.begin(), metric_names.end());
388 arena->ForEachMetric(
389 [&ss, &metric_name_set](const std::string& name, MetricData* data) {
390 if (metric_name_set.find(name) != metric_name_set.end()) {
391 EmitMetricInfo(name, data, &ss);
392 }
393 });
394 std::set<std::string> counter_name_set(
395 counter_names.begin(), counter_names.end());
396 arena->ForEachCounter(
397 [&ss, &counter_name_set](const std::string& name, CounterData* data) {
398 if (counter_name_set.find(name) != counter_name_set.end()) {
399 EmitCounterInfo(name, data, &ss);
400 }
401 });
402
403 static std::string fall_back_counter_prefix = "aten::";
404 arena->ForEachCounter([&ss](const std::string& name, CounterData* data) {
405 if (name.rfind(fall_back_counter_prefix, 0) == 0) {
406 // it might emit duplicated counter if user also specified exact aten
407 // counter in the `counter_names` but it should be very rare.
408 EmitCounterInfo(name, data, &ss);
409 }
410 });
411 return ss.str();
412}
413
414std::vector<std::string> GetMetricNames() {
415 return MetricsArena::Get()->GetMetricNames();
416}
417
418MetricData* GetMetric(const std::string& name) {
419 return MetricsArena::Get()->GetMetric(name);
420}
421
422std::vector<std::string> GetCounterNames() {
423 return MetricsArena::Get()->GetCounterNames();
424}
425
426CounterData* GetCounter(const std::string& name) {
427 return MetricsArena::Get()->GetCounter(name);
428}
429
// Returns the current std::chrono::high_resolution_clock reading, expressed
// as nanoseconds since that clock's epoch.
int64_t NowNs() {
  using Clock = std::chrono::high_resolution_clock;
  const Clock::duration since_epoch = Clock::now().time_since_epoch();
  return std::chrono::duration_cast<std::chrono::nanoseconds>(since_epoch)
      .count();
}
436
437} // namespace lazy
438} // namespace torch
439