#include <torch/csrc/monitor/counters.h>
#include <torch/csrc/monitor/events.h>

#include <cstdint>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
6 | |
7 | namespace torch { |
8 | namespace monitor { |
9 | |
10 | const char* aggregationName(Aggregation agg) { |
11 | switch (agg) { |
12 | case Aggregation::NONE: |
13 | return "none"; |
14 | case Aggregation::VALUE: |
15 | return "value"; |
16 | case Aggregation::MEAN: |
17 | return "mean"; |
18 | case Aggregation::COUNT: |
19 | return "count"; |
20 | case Aggregation::SUM: |
21 | return "sum"; |
22 | case Aggregation::MAX: |
23 | return "max"; |
24 | case Aggregation::MIN: |
25 | return "min"; |
26 | default: |
27 | throw std::runtime_error( |
28 | "unknown aggregation: "+ std::to_string(static_cast<int>(agg))); |
29 | } |
30 | } |
31 | |
32 | namespace { |
33 | struct Stats { |
34 | std::mutex mu; |
35 | |
36 | std::unordered_set<Stat<double>*> doubles; |
37 | std::unordered_set<Stat<int64_t>*> int64s; |
38 | }; |
39 | |
40 | Stats& stats() { |
41 | static Stats stats; |
42 | return stats; |
43 | } |
44 | } // namespace |
45 | |
46 | namespace detail { |
47 | void registerStat(Stat<double>* stat) { |
48 | std::lock_guard<std::mutex> guard(stats().mu); |
49 | |
50 | stats().doubles.insert(stat); |
51 | } |
52 | void registerStat(Stat<int64_t>* stat) { |
53 | std::lock_guard<std::mutex> guard(stats().mu); |
54 | |
55 | stats().int64s.insert(stat); |
56 | } |
57 | void unregisterStat(Stat<double>* stat) { |
58 | std::lock_guard<std::mutex> guard(stats().mu); |
59 | |
60 | stats().doubles.erase(stat); |
61 | } |
62 | void unregisterStat(Stat<int64_t>* stat) { |
63 | std::lock_guard<std::mutex> guard(stats().mu); |
64 | |
65 | stats().int64s.erase(stat); |
66 | } |
67 | } // namespace detail |
68 | |
69 | } // namespace monitor |
70 | } // namespace torch |
71 |