1#pragma once
2
3#include <bitset>
4#include <mutex>
5#include <sstream>
6#include <unordered_map>
7#include <vector>
8
9#include <c10/macros/Macros.h>
10
11#include <torch/csrc/monitor/events.h>
12
13namespace torch {
14namespace monitor {
15
// NUM_AGGREGATIONS is the number of Aggregation enumerators (NONE through
// MIN, values 0..6). It sizes the std::bitset that stores which aggregations
// are enabled; because enumerator values are used as bit positions, it must
// stay greater than the largest enumerator value.
constexpr int NUM_AGGREGATIONS = 7;
17
// Aggregation is the list of possible aggregations for Stats.
// Each enumerator's value is used as a bit position in a
// std::bitset<NUM_AGGREGATIONS>, so a set of aggregations can be stored
// compactly.
enum class C10_API_ENUM Aggregation {
  // NONE means no aggregations are set.
  NONE = 0,
  // VALUE exports the most recently set value.
  VALUE = 1,
  // MEAN computes the mean of the set values within the window. Zero if no
  // values.
  MEAN = 2,
  // COUNT tracks the number of times a value is set within the window.
  COUNT = 3,
  // SUM computes the sum of the values set within the window.
  SUM = 4,
  // MAX computes the maximum of the values set within the window. Zero if no
  // values.
  MAX = 5,
  // MIN computes the minimum of the values set within the window. Zero if no
  // values.
  MIN = 6,
};
39
40struct TORCH_API AggregationHash {
41 template <typename T>
42 std::size_t operator()(T t) const {
43 return static_cast<std::size_t>(t);
44 }
45};
46
// aggregationName returns the human readable name corresponding to the
// aggregation. It is only declared here; the definition is not part of this
// header. Stat uses the returned name as the suffix of every key it logs
// (`<stat name>.<aggregation name>`).
TORCH_API const char* aggregationName(Aggregation agg);
50
// Forward declaration so that the detail::registerStat / unregisterStat
// overloads declared below can take a Stat<T>* before the class template is
// defined.
template <typename T>
class Stat;
53
54namespace {
55template <typename T>
56inline std::bitset<NUM_AGGREGATIONS> merge(T& list) {
57 std::bitset<NUM_AGGREGATIONS> a;
58 for (Aggregation b : list) {
59 a.set(static_cast<int>(b));
60 }
61 return a;
62}
63} // namespace
64
namespace detail {
// Registration hooks for Stat instances. Every Stat constructor calls
// registerStat(this) and the Stat destructor calls unregisterStat(this).
// Overloads exist only for the two value types Stat supports, double and
// int64_t. They are declared here; the definitions are not part of this
// header.
void TORCH_API registerStat(Stat<double>* stat);
void TORCH_API registerStat(Stat<int64_t>* stat);
void TORCH_API unregisterStat(Stat<double>* stat);
void TORCH_API unregisterStat(Stat<int64_t>* stat);
} // namespace detail
71
72// Stat is used to compute summary statistics in a performant way over fixed
73// intervals. Stat logs the statistics as an Event once every `windowSize`
74// duration. When the window closes the stats are logged via the event handlers
75// as a `torch.monitor.Stat` event.
76//
77// `windowSize` should be set to something relatively high to avoid a huge
78// number of events being logged. Ex: 60s. Stat uses millisecond precision.
79//
80// If maxSamples is set, the stat will cap the number of samples per window by
81// discarding `add` calls once `maxSamples` adds have occurred. If it's not set,
82// all `add` calls during the window will be included.
83// This is an optional field to make aggregations more directly comparable
84// across windows when the number of samples might vary.
85//
86// Stats support double and int64_t data types depending on what needs to be
87// logged and needs to be templatized with one of them.
88//
89// When the Stat is destructed it will log any remaining data even if the window
90// hasn't elapsed.
91template <typename T>
92class Stat {
93 private:
94 struct Values {
95 T value{0};
96 T sum{0};
97 T min{0};
98 T max{0};
99 int64_t count{0};
100 };
101
102 public:
103 Stat(
104 std::string name,
105 std::initializer_list<Aggregation> aggregations,
106 std::chrono::milliseconds windowSize,
107 int64_t maxSamples = std::numeric_limits<int64_t>::max())
108 : name_(std::move(name)),
109 aggregations_(merge(aggregations)),
110 windowSize_(windowSize),
111 maxSamples_(maxSamples) {
112 detail::registerStat(this);
113 }
114
115 Stat(
116 std::string name,
117 std::vector<Aggregation> aggregations,
118 std::chrono::milliseconds windowSize,
119 int64_t maxSamples = std::numeric_limits<int64_t>::max())
120 : name_(std::move(name)),
121 aggregations_(merge(aggregations)),
122 windowSize_(windowSize),
123 maxSamples_(maxSamples) {
124 detail::registerStat(this);
125 }
126
127 virtual ~Stat() {
128 {
129 // on destruction log if there's unlogged data
130 std::lock_guard<std::mutex> guard(mu_);
131 logLocked();
132 }
133 detail::unregisterStat(this);
134 }
135
136 // add adds the value v to the current window.
137 void add(T v) {
138 std::lock_guard<std::mutex> guard(mu_);
139 maybeLogLocked();
140
141 if (alreadyLogged()) {
142 return;
143 }
144
145 if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
146 current_.value = v;
147 }
148 if (aggregations_.test(static_cast<int>(Aggregation::MEAN)) ||
149 aggregations_.test(static_cast<int>(Aggregation::SUM))) {
150 current_.sum += v;
151 }
152
153 if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
154 if (current_.max < v || current_.count == 0) {
155 current_.max = v;
156 }
157 }
158 if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
159 if (current_.min > v || current_.count == 0) {
160 current_.min = v;
161 }
162 }
163
164 current_.count += 1;
165 maybeLogLocked();
166 }
167
168 const std::string& name() const noexcept {
169 return name_;
170 }
171
172 // count returns the number of items in the current open window.
173 int64_t count() noexcept {
174 std::lock_guard<std::mutex> guard(mu_);
175
176 return current_.count;
177 }
178
179 std::unordered_map<Aggregation, T, AggregationHash> get() noexcept {
180 std::lock_guard<std::mutex> guard(mu_);
181 return getLocked();
182 }
183
184 protected:
185 virtual uint64_t currentWindowId() const {
186 std::chrono::milliseconds now =
187 std::chrono::duration_cast<std::chrono::milliseconds>(
188 std::chrono::steady_clock::now().time_since_epoch());
189
190 // always returns a currentWindowId of at least 1 to avoid 0 window issues
191 return (now / windowSize_) + 1;
192 }
193
194 private:
195 bool alreadyLogged() {
196 return lastLoggedWindowId_ == currentWindowId();
197 }
198
199 void maybeLogLocked() {
200 auto windowId = currentWindowId();
201 bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_;
202 if (shouldLog && !alreadyLogged()) {
203 logLocked();
204 lastLoggedWindowId_ = windowId_;
205 windowId_ = windowId;
206 }
207 }
208
209 void logLocked() {
210 prev_ = current_;
211 current_ = Values();
212
213 // don't log event if there's no data
214 if (prev_.count == 0) {
215 return;
216 }
217
218 Event e;
219 e.name = "torch.monitor.Stat";
220 e.timestamp = std::chrono::system_clock::now();
221
222 auto stats = getLocked();
223 e.data.reserve(stats.size());
224 for (auto& kv : stats) {
225 std::stringstream key;
226 key << name_;
227 key << ".";
228 key << aggregationName(kv.first);
229 e.data[key.str()] = kv.second;
230 }
231
232 logEvent(e);
233 }
234
235 std::unordered_map<Aggregation, T, AggregationHash> getLocked()
236 const noexcept {
237 std::unordered_map<Aggregation, T, AggregationHash> out;
238 out.reserve(aggregations_.count());
239
240 if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
241 out.emplace(Aggregation::VALUE, prev_.value);
242 }
243 if (aggregations_.test(static_cast<int>(Aggregation::MEAN))) {
244 if (prev_.count == 0) {
245 out.emplace(Aggregation::MEAN, 0);
246 } else {
247 out.emplace(Aggregation::MEAN, prev_.sum / prev_.count);
248 }
249 }
250 if (aggregations_.test(static_cast<int>(Aggregation::COUNT))) {
251 out.emplace(Aggregation::COUNT, prev_.count);
252 }
253 if (aggregations_.test(static_cast<int>(Aggregation::SUM))) {
254 out.emplace(Aggregation::SUM, prev_.sum);
255 }
256 if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
257 out.emplace(Aggregation::MAX, prev_.max);
258 }
259 if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
260 out.emplace(Aggregation::MIN, prev_.min);
261 }
262
263 return out;
264 }
265
266 const std::string name_;
267 const std::bitset<NUM_AGGREGATIONS> aggregations_;
268
269 std::mutex mu_;
270 Values current_;
271 Values prev_;
272
273 uint64_t windowId_{0};
274 uint64_t lastLoggedWindowId_{0};
275 const std::chrono::milliseconds windowSize_;
276 const int64_t maxSamples_;
277};
278} // namespace monitor
279} // namespace torch
280