1#include <torch/csrc/monitor/counters.h>
2#include <torch/csrc/monitor/events.h>
3
4#include <sstream>
5#include <unordered_set>
6
7namespace torch {
8namespace monitor {
9
10const char* aggregationName(Aggregation agg) {
11 switch (agg) {
12 case Aggregation::NONE:
13 return "none";
14 case Aggregation::VALUE:
15 return "value";
16 case Aggregation::MEAN:
17 return "mean";
18 case Aggregation::COUNT:
19 return "count";
20 case Aggregation::SUM:
21 return "sum";
22 case Aggregation::MAX:
23 return "max";
24 case Aggregation::MIN:
25 return "min";
26 default:
27 throw std::runtime_error(
28 "unknown aggregation: " + std::to_string(static_cast<int>(agg)));
29 }
30}
31
32namespace {
33struct Stats {
34 std::mutex mu;
35
36 std::unordered_set<Stat<double>*> doubles;
37 std::unordered_set<Stat<int64_t>*> int64s;
38};
39
40Stats& stats() {
41 static Stats stats;
42 return stats;
43}
44} // namespace
45
46namespace detail {
47void registerStat(Stat<double>* stat) {
48 std::lock_guard<std::mutex> guard(stats().mu);
49
50 stats().doubles.insert(stat);
51}
52void registerStat(Stat<int64_t>* stat) {
53 std::lock_guard<std::mutex> guard(stats().mu);
54
55 stats().int64s.insert(stat);
56}
57void unregisterStat(Stat<double>* stat) {
58 std::lock_guard<std::mutex> guard(stats().mu);
59
60 stats().doubles.erase(stat);
61}
62void unregisterStat(Stat<int64_t>* stat) {
63 std::lock_guard<std::mutex> guard(stats().mu);
64
65 stats().int64s.erase(stat);
66}
67} // namespace detail
68
69} // namespace monitor
70} // namespace torch
71