1 | #pragma once |
2 | |
#include <bitset>
#include <chrono>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>

#include <torch/csrc/monitor/events.h>
12 | |
13 | namespace torch { |
14 | namespace monitor { |
15 | |
16 | constexpr int NUM_AGGREGATIONS = 7; |
17 | |
18 | // Aggregation is the list of possible aggregations for Stats. |
19 | // These use bitwise flags so they can be efficiently stored. |
20 | enum class C10_API_ENUM Aggregation { |
21 | // NONE means no aggregations are set. |
22 | NONE = 0, |
23 | // VALUE exports the most recently set value. |
24 | VALUE = 1, |
25 | // MEAN computes the mean of the set values within the window. Zero if no |
26 | // values. |
27 | MEAN = 2, |
28 | // COUNT tracks the number of times a value is set within the window. |
29 | COUNT = 3, |
30 | // SUM computes the sum of the values set within the window. |
31 | SUM = 4, |
32 | // MIN computes the minimum of the values set within the window. Zero if no |
33 | // values. |
34 | MAX = 5, |
35 | // MAX computes the maximum of the values set within the window. Zero if no |
36 | // values. |
37 | MIN = 6, |
38 | }; |
39 | |
40 | struct TORCH_API AggregationHash { |
41 | template <typename T> |
42 | std::size_t operator()(T t) const { |
43 | return static_cast<std::size_t>(t); |
44 | } |
45 | }; |
46 | |
47 | // aggregationName returns the human readable name corresponding to the |
48 | // aggregation. |
49 | TORCH_API const char* aggregationName(Aggregation agg); |
50 | |
51 | template <typename T> |
52 | class Stat; |
53 | |
54 | namespace { |
55 | template <typename T> |
56 | inline std::bitset<NUM_AGGREGATIONS> merge(T& list) { |
57 | std::bitset<NUM_AGGREGATIONS> a; |
58 | for (Aggregation b : list) { |
59 | a.set(static_cast<int>(b)); |
60 | } |
61 | return a; |
62 | } |
63 | } // namespace |
64 | |
65 | namespace detail { |
66 | void TORCH_API registerStat(Stat<double>* stat); |
67 | void TORCH_API registerStat(Stat<int64_t>* stat); |
68 | void TORCH_API unregisterStat(Stat<double>* stat); |
69 | void TORCH_API unregisterStat(Stat<int64_t>* stat); |
70 | } // namespace detail |
71 | |
72 | // Stat is used to compute summary statistics in a performant way over fixed |
73 | // intervals. Stat logs the statistics as an Event once every `windowSize` |
74 | // duration. When the window closes the stats are logged via the event handlers |
75 | // as a `torch.monitor.Stat` event. |
76 | // |
77 | // `windowSize` should be set to something relatively high to avoid a huge |
78 | // number of events being logged. Ex: 60s. Stat uses millisecond precision. |
79 | // |
80 | // If maxSamples is set, the stat will cap the number of samples per window by |
81 | // discarding `add` calls once `maxSamples` adds have occurred. If it's not set, |
82 | // all `add` calls during the window will be included. |
83 | // This is an optional field to make aggregations more directly comparable |
84 | // across windows when the number of samples might vary. |
85 | // |
86 | // Stats support double and int64_t data types depending on what needs to be |
87 | // logged and needs to be templatized with one of them. |
88 | // |
89 | // When the Stat is destructed it will log any remaining data even if the window |
90 | // hasn't elapsed. |
91 | template <typename T> |
92 | class Stat { |
93 | private: |
94 | struct Values { |
95 | T value{0}; |
96 | T sum{0}; |
97 | T min{0}; |
98 | T max{0}; |
99 | int64_t count{0}; |
100 | }; |
101 | |
102 | public: |
103 | Stat( |
104 | std::string name, |
105 | std::initializer_list<Aggregation> aggregations, |
106 | std::chrono::milliseconds windowSize, |
107 | int64_t maxSamples = std::numeric_limits<int64_t>::max()) |
108 | : name_(std::move(name)), |
109 | aggregations_(merge(aggregations)), |
110 | windowSize_(windowSize), |
111 | maxSamples_(maxSamples) { |
112 | detail::registerStat(this); |
113 | } |
114 | |
115 | Stat( |
116 | std::string name, |
117 | std::vector<Aggregation> aggregations, |
118 | std::chrono::milliseconds windowSize, |
119 | int64_t maxSamples = std::numeric_limits<int64_t>::max()) |
120 | : name_(std::move(name)), |
121 | aggregations_(merge(aggregations)), |
122 | windowSize_(windowSize), |
123 | maxSamples_(maxSamples) { |
124 | detail::registerStat(this); |
125 | } |
126 | |
127 | virtual ~Stat() { |
128 | { |
129 | // on destruction log if there's unlogged data |
130 | std::lock_guard<std::mutex> guard(mu_); |
131 | logLocked(); |
132 | } |
133 | detail::unregisterStat(this); |
134 | } |
135 | |
136 | // add adds the value v to the current window. |
137 | void add(T v) { |
138 | std::lock_guard<std::mutex> guard(mu_); |
139 | maybeLogLocked(); |
140 | |
141 | if (alreadyLogged()) { |
142 | return; |
143 | } |
144 | |
145 | if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) { |
146 | current_.value = v; |
147 | } |
148 | if (aggregations_.test(static_cast<int>(Aggregation::MEAN)) || |
149 | aggregations_.test(static_cast<int>(Aggregation::SUM))) { |
150 | current_.sum += v; |
151 | } |
152 | |
153 | if (aggregations_.test(static_cast<int>(Aggregation::MAX))) { |
154 | if (current_.max < v || current_.count == 0) { |
155 | current_.max = v; |
156 | } |
157 | } |
158 | if (aggregations_.test(static_cast<int>(Aggregation::MIN))) { |
159 | if (current_.min > v || current_.count == 0) { |
160 | current_.min = v; |
161 | } |
162 | } |
163 | |
164 | current_.count += 1; |
165 | maybeLogLocked(); |
166 | } |
167 | |
168 | const std::string& name() const noexcept { |
169 | return name_; |
170 | } |
171 | |
172 | // count returns the number of items in the current open window. |
173 | int64_t count() noexcept { |
174 | std::lock_guard<std::mutex> guard(mu_); |
175 | |
176 | return current_.count; |
177 | } |
178 | |
179 | std::unordered_map<Aggregation, T, AggregationHash> get() noexcept { |
180 | std::lock_guard<std::mutex> guard(mu_); |
181 | return getLocked(); |
182 | } |
183 | |
184 | protected: |
185 | virtual uint64_t currentWindowId() const { |
186 | std::chrono::milliseconds now = |
187 | std::chrono::duration_cast<std::chrono::milliseconds>( |
188 | std::chrono::steady_clock::now().time_since_epoch()); |
189 | |
190 | // always returns a currentWindowId of at least 1 to avoid 0 window issues |
191 | return (now / windowSize_) + 1; |
192 | } |
193 | |
194 | private: |
195 | bool alreadyLogged() { |
196 | return lastLoggedWindowId_ == currentWindowId(); |
197 | } |
198 | |
199 | void maybeLogLocked() { |
200 | auto windowId = currentWindowId(); |
201 | bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_; |
202 | if (shouldLog && !alreadyLogged()) { |
203 | logLocked(); |
204 | lastLoggedWindowId_ = windowId_; |
205 | windowId_ = windowId; |
206 | } |
207 | } |
208 | |
209 | void logLocked() { |
210 | prev_ = current_; |
211 | current_ = Values(); |
212 | |
213 | // don't log event if there's no data |
214 | if (prev_.count == 0) { |
215 | return; |
216 | } |
217 | |
218 | Event e; |
219 | e.name = "torch.monitor.Stat" ; |
220 | e.timestamp = std::chrono::system_clock::now(); |
221 | |
222 | auto stats = getLocked(); |
223 | e.data.reserve(stats.size()); |
224 | for (auto& kv : stats) { |
225 | std::stringstream key; |
226 | key << name_; |
227 | key << "." ; |
228 | key << aggregationName(kv.first); |
229 | e.data[key.str()] = kv.second; |
230 | } |
231 | |
232 | logEvent(e); |
233 | } |
234 | |
235 | std::unordered_map<Aggregation, T, AggregationHash> getLocked() |
236 | const noexcept { |
237 | std::unordered_map<Aggregation, T, AggregationHash> out; |
238 | out.reserve(aggregations_.count()); |
239 | |
240 | if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) { |
241 | out.emplace(Aggregation::VALUE, prev_.value); |
242 | } |
243 | if (aggregations_.test(static_cast<int>(Aggregation::MEAN))) { |
244 | if (prev_.count == 0) { |
245 | out.emplace(Aggregation::MEAN, 0); |
246 | } else { |
247 | out.emplace(Aggregation::MEAN, prev_.sum / prev_.count); |
248 | } |
249 | } |
250 | if (aggregations_.test(static_cast<int>(Aggregation::COUNT))) { |
251 | out.emplace(Aggregation::COUNT, prev_.count); |
252 | } |
253 | if (aggregations_.test(static_cast<int>(Aggregation::SUM))) { |
254 | out.emplace(Aggregation::SUM, prev_.sum); |
255 | } |
256 | if (aggregations_.test(static_cast<int>(Aggregation::MAX))) { |
257 | out.emplace(Aggregation::MAX, prev_.max); |
258 | } |
259 | if (aggregations_.test(static_cast<int>(Aggregation::MIN))) { |
260 | out.emplace(Aggregation::MIN, prev_.min); |
261 | } |
262 | |
263 | return out; |
264 | } |
265 | |
266 | const std::string name_; |
267 | const std::bitset<NUM_AGGREGATIONS> aggregations_; |
268 | |
269 | std::mutex mu_; |
270 | Values current_; |
271 | Values prev_; |
272 | |
273 | uint64_t windowId_{0}; |
274 | uint64_t lastLoggedWindowId_{0}; |
275 | const std::chrono::milliseconds windowSize_; |
276 | const int64_t maxSamples_; |
277 | }; |
278 | } // namespace monitor |
279 | } // namespace torch |
280 | |