1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | Licensed under the Apache License, Version 2.0 (the "License"); |
3 | you may not use this file except in compliance with the License. |
4 | You may obtain a copy of the License at |
5 | |
6 | http://www.apache.org/licenses/LICENSE-2.0 |
7 | |
8 | Unless required by applicable law or agreed to in writing, software |
9 | distributed under the License is distributed on an "AS IS" BASIS, |
10 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
11 | See the License for the specific language governing permissions and |
12 | limitations under the License. |
13 | ==============================================================================*/ |
14 | |
15 | #include "tensorflow/tsl/platform/numbers.h" |
16 | |
17 | #include <ctype.h> |
18 | #include <float.h> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | |
22 | #include <algorithm> |
23 | #include <cinttypes> |
24 | #include <cmath> |
25 | #include <cstdint> |
26 | #include <locale> |
27 | #include <unordered_map> |
28 | |
29 | #include "double-conversion/double-conversion.h" |
30 | #include "tensorflow/tsl/platform/str_util.h" |
31 | #include "tensorflow/tsl/platform/logging.h" |
32 | #include "tensorflow/tsl/platform/macros.h" |
33 | #include "tensorflow/tsl/platform/stringprintf.h" |
34 | #include "tensorflow/tsl/platform/types.h" |
35 | |
36 | namespace tsl { |
37 | |
38 | namespace { |
39 | |
// Returns the lazily-created lookup table that maps the lower-case textual
// spellings of the special floating-point values (signed/unsigned "inf",
// "infinity" and "nan") to the matching value of type T. The table is
// allocated on first use and intentionally never destroyed.
template <typename T>
const std::unordered_map<std::string, T>* GetSpecialNumsSingleton() {
  using Limits = std::numeric_limits<T>;
  using Table = std::unordered_map<std::string, T>;
  static const Table* table = CHECK_NOTNULL((new const Table{
      // Infinities, short and long spelling.
      {"inf", Limits::infinity()},
      {"+inf", Limits::infinity()},
      {"-inf", -Limits::infinity()},
      {"infinity", Limits::infinity()},
      {"+infinity", Limits::infinity()},
      {"-infinity", -Limits::infinity()},
      // Quiet NaNs.
      {"nan", Limits::quiet_NaN()},
      {"+nan", Limits::quiet_NaN()},
      {"-nan", -Limits::quiet_NaN()},
  }));
  return table;
}
56 | |
57 | template <typename T> |
58 | T locale_independent_strtonum(const char* str, const char** endptr) { |
59 | auto special_nums = GetSpecialNumsSingleton<T>(); |
60 | std::stringstream s(str); |
61 | |
62 | // Check if str is one of the special numbers. |
63 | std::string special_num_str; |
64 | s >> special_num_str; |
65 | |
66 | for (size_t i = 0; i < special_num_str.length(); ++i) { |
67 | special_num_str[i] = |
68 | std::tolower(special_num_str[i], std::locale::classic()); |
69 | } |
70 | |
71 | auto entry = special_nums->find(special_num_str); |
72 | if (entry != special_nums->end()) { |
73 | *endptr = str + (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str)) |
74 | : s.tellg()); |
75 | return entry->second; |
76 | } else { |
77 | // Perhaps it's a hex number |
78 | if (special_num_str.compare(0, 2, "0x" ) == 0 || |
79 | special_num_str.compare(0, 3, "-0x" ) == 0) { |
80 | return strtol(str, const_cast<char**>(endptr), 16); |
81 | } |
82 | } |
83 | // Reset the stream |
84 | s.str(str); |
85 | s.clear(); |
86 | // Use the "C" locale |
87 | s.imbue(std::locale::classic()); |
88 | |
89 | T result; |
90 | s >> result; |
91 | |
92 | // Set to result to what strto{f,d} functions would have returned. If the |
93 | // number was outside the range, the stringstream sets the fail flag, but |
94 | // returns the +/-max() value, whereas strto{f,d} functions return +/-INF. |
95 | if (s.fail()) { |
96 | if (result == std::numeric_limits<T>::max() || |
97 | result == std::numeric_limits<T>::infinity()) { |
98 | result = std::numeric_limits<T>::infinity(); |
99 | s.clear(s.rdstate() & ~std::ios::failbit); |
100 | } else if (result == -std::numeric_limits<T>::max() || |
101 | result == -std::numeric_limits<T>::infinity()) { |
102 | result = -std::numeric_limits<T>::infinity(); |
103 | s.clear(s.rdstate() & ~std::ios::failbit); |
104 | } |
105 | } |
106 | |
107 | if (endptr) { |
108 | *endptr = |
109 | str + |
110 | (s.fail() ? static_cast<std::iostream::pos_type>(0) |
111 | : (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str)) |
112 | : s.tellg())); |
113 | } |
114 | return result; |
115 | } |
116 | |
117 | static inline const double_conversion::StringToDoubleConverter& |
118 | StringToFloatConverter() { |
119 | static const double_conversion::StringToDoubleConverter converter( |
120 | double_conversion::StringToDoubleConverter::ALLOW_LEADING_SPACES | |
121 | double_conversion::StringToDoubleConverter::ALLOW_HEX | |
122 | double_conversion::StringToDoubleConverter::ALLOW_TRAILING_SPACES | |
123 | double_conversion::StringToDoubleConverter::ALLOW_CASE_INSENSIBILITY, |
124 | 0., 0., "inf" , "nan" ); |
125 | return converter; |
126 | } |
127 | |
128 | } // namespace |
129 | |
130 | namespace strings { |
131 | |
132 | size_t FastInt32ToBufferLeft(int32_t i, char* buffer) { |
133 | uint32_t u = i; |
134 | size_t length = 0; |
135 | if (i < 0) { |
136 | *buffer++ = '-'; |
137 | ++length; |
138 | // We need to do the negation in modular (i.e., "unsigned") |
139 | // arithmetic; MSVC++ apparently warns for plain "-u", so |
140 | // we write the equivalent expression "0 - u" instead. |
141 | u = 0 - u; |
142 | } |
143 | length += FastUInt32ToBufferLeft(u, buffer); |
144 | return length; |
145 | } |
146 | |
// Writes the decimal representation of `i` into `buffer`, followed by a NUL
// terminator. Returns the number of digits written (terminator excluded).
size_t FastUInt32ToBufferLeft(uint32_t i, char* buffer) {
  // Count the digits first so each one can be stored at its final position.
  size_t num_digits = 1;
  for (uint32_t rest = i / 10; rest != 0; rest /= 10) {
    ++num_digits;
  }

  buffer[num_digits] = 0;
  // Fill from the least significant digit backwards.
  for (size_t pos = num_digits; pos-- > 0;) {
    buffer[pos] = ((i % 10) + '0');
    i /= 10;
  }
  return num_digits;
}
157 | |
158 | size_t FastInt64ToBufferLeft(int64_t i, char* buffer) { |
159 | uint64_t u = i; |
160 | size_t length = 0; |
161 | if (i < 0) { |
162 | *buffer++ = '-'; |
163 | ++length; |
164 | u = 0 - u; |
165 | } |
166 | length += FastUInt64ToBufferLeft(u, buffer); |
167 | return length; |
168 | } |
169 | |
// Writes the decimal representation of `i` into `buffer`, followed by a NUL
// terminator. Returns the number of digits written (terminator excluded).
size_t FastUInt64ToBufferLeft(uint64_t i, char* buffer) {
  // Count the digits first so each one can be stored at its final position.
  size_t num_digits = 1;
  for (uint64_t rest = i / 10; rest != 0; rest /= 10) {
    ++num_digits;
  }

  buffer[num_digits] = 0;
  // Fill from the least significant digit backwards.
  for (size_t pos = num_digits; pos-- > 0;) {
    buffer[pos] = ((i % 10) + '0');
    i /= 10;
  }
  return num_digits;
}
180 | |
// Largest magnitude for which DoubleToBuffer() first tries the shorter
// DBL_DIG-digit format and verifies it by parsing the text back. The value
// sits just below DBL_MAX.
// NOTE(review): presumably the margin keeps the rounded DBL_DIG-digit text
// from exceeding DBL_MAX when re-parsed -- confirm.
static const double kDoublePrecisionCheckMax = DBL_MAX / 1.000000000000001;
182 | |
183 | size_t DoubleToBuffer(double value, char* buffer) { |
184 | // DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all |
185 | // platforms these days. Just in case some system exists where DBL_DIG |
186 | // is significantly larger -- and risks overflowing our buffer -- we have |
187 | // this assert. |
188 | static_assert(DBL_DIG < 20, "DBL_DIG is too big" ); |
189 | |
190 | if (std::isnan(value)) { |
191 | int snprintf_result = snprintf(buffer, kFastToBufferSize, "%snan" , |
192 | std::signbit(value) ? "-" : "" ); |
193 | // Paranoid check to ensure we don't overflow the buffer. |
194 | DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); |
195 | return snprintf_result; |
196 | } |
197 | |
198 | if (std::abs(value) <= kDoublePrecisionCheckMax) { |
199 | int snprintf_result = |
200 | snprintf(buffer, kFastToBufferSize, "%.*g" , DBL_DIG, value); |
201 | |
202 | // The snprintf should never overflow because the buffer is significantly |
203 | // larger than the precision we asked for. |
204 | DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); |
205 | |
206 | if (locale_independent_strtonum<double>(buffer, nullptr) == value) { |
207 | // Round-tripping the string to double works; we're done. |
208 | return snprintf_result; |
209 | } |
210 | // else: full precision formatting needed. Fall through. |
211 | } |
212 | |
213 | int snprintf_result = |
214 | snprintf(buffer, kFastToBufferSize, "%.*g" , DBL_DIG + 2, value); |
215 | |
216 | // Should never overflow; see above. |
217 | DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); |
218 | |
219 | return snprintf_result; |
220 | } |
221 | |
222 | namespace { |
223 | char SafeFirstChar(StringPiece str) { |
224 | if (str.empty()) return '\0'; |
225 | return str[0]; |
226 | } |
227 | void SkipSpaces(StringPiece* str) { |
228 | while (isspace(SafeFirstChar(*str))) str->remove_prefix(1); |
229 | } |
230 | } // namespace |
231 | |
232 | bool safe_strto64(StringPiece str, int64_t* value) { |
233 | SkipSpaces(&str); |
234 | |
235 | int64_t vlimit = kint64max; |
236 | int sign = 1; |
237 | if (absl::ConsumePrefix(&str, "-" )) { |
238 | sign = -1; |
239 | // Different limit for positive and negative integers. |
240 | vlimit = kint64min; |
241 | } |
242 | |
243 | if (!isdigit(SafeFirstChar(str))) return false; |
244 | |
245 | int64_t result = 0; |
246 | if (sign == 1) { |
247 | do { |
248 | int digit = SafeFirstChar(str) - '0'; |
249 | if ((vlimit - digit) / 10 < result) { |
250 | return false; |
251 | } |
252 | result = result * 10 + digit; |
253 | str.remove_prefix(1); |
254 | } while (isdigit(SafeFirstChar(str))); |
255 | } else { |
256 | do { |
257 | int digit = SafeFirstChar(str) - '0'; |
258 | if ((vlimit + digit) / 10 > result) { |
259 | return false; |
260 | } |
261 | result = result * 10 - digit; |
262 | str.remove_prefix(1); |
263 | } while (isdigit(SafeFirstChar(str))); |
264 | } |
265 | |
266 | SkipSpaces(&str); |
267 | if (!str.empty()) return false; |
268 | |
269 | *value = result; |
270 | return true; |
271 | } |
272 | |
273 | bool safe_strtou64(StringPiece str, uint64_t* value) { |
274 | SkipSpaces(&str); |
275 | if (!isdigit(SafeFirstChar(str))) return false; |
276 | |
277 | uint64_t result = 0; |
278 | do { |
279 | int digit = SafeFirstChar(str) - '0'; |
280 | if ((kuint64max - digit) / 10 < result) { |
281 | return false; |
282 | } |
283 | result = result * 10 + digit; |
284 | str.remove_prefix(1); |
285 | } while (isdigit(SafeFirstChar(str))); |
286 | |
287 | SkipSpaces(&str); |
288 | if (!str.empty()) return false; |
289 | |
290 | *value = result; |
291 | return true; |
292 | } |
293 | |
294 | bool safe_strto32(StringPiece str, int32_t* value) { |
295 | SkipSpaces(&str); |
296 | |
297 | int64_t vmax = kint32max; |
298 | int sign = 1; |
299 | if (absl::ConsumePrefix(&str, "-" )) { |
300 | sign = -1; |
301 | // Different max for positive and negative integers. |
302 | ++vmax; |
303 | } |
304 | |
305 | if (!isdigit(SafeFirstChar(str))) return false; |
306 | |
307 | int64_t result = 0; |
308 | do { |
309 | result = result * 10 + SafeFirstChar(str) - '0'; |
310 | if (result > vmax) { |
311 | return false; |
312 | } |
313 | str.remove_prefix(1); |
314 | } while (isdigit(SafeFirstChar(str))); |
315 | |
316 | SkipSpaces(&str); |
317 | |
318 | if (!str.empty()) return false; |
319 | |
320 | *value = static_cast<int32_t>(result * sign); |
321 | return true; |
322 | } |
323 | |
324 | bool safe_strtou32(StringPiece str, uint32_t* value) { |
325 | SkipSpaces(&str); |
326 | if (!isdigit(SafeFirstChar(str))) return false; |
327 | |
328 | int64_t result = 0; |
329 | do { |
330 | result = result * 10 + SafeFirstChar(str) - '0'; |
331 | if (result > kuint32max) { |
332 | return false; |
333 | } |
334 | str.remove_prefix(1); |
335 | } while (isdigit(SafeFirstChar(str))); |
336 | |
337 | SkipSpaces(&str); |
338 | if (!str.empty()) return false; |
339 | |
340 | *value = static_cast<uint32_t>(result); |
341 | return true; |
342 | } |
343 | |
344 | bool safe_strtof(StringPiece str, float* value) { |
345 | int processed_characters_count = -1; |
346 | auto len = str.size(); |
347 | |
348 | // If string length exceeds buffer size or int max, fail. |
349 | if (len >= kFastToBufferSize) return false; |
350 | if (len > std::numeric_limits<int>::max()) return false; |
351 | |
352 | *value = StringToFloatConverter().StringToFloat( |
353 | str.data(), static_cast<int>(len), &processed_characters_count); |
354 | return processed_characters_count > 0; |
355 | } |
356 | |
357 | bool safe_strtod(StringPiece str, double* value) { |
358 | int processed_characters_count = -1; |
359 | auto len = str.size(); |
360 | |
361 | // If string length exceeds buffer size or int max, fail. |
362 | if (len >= kFastToBufferSize) return false; |
363 | if (len > std::numeric_limits<int>::max()) return false; |
364 | |
365 | *value = StringToFloatConverter().StringToDouble( |
366 | str.data(), static_cast<int>(len), &processed_characters_count); |
367 | return processed_characters_count > 0; |
368 | } |
369 | |
370 | size_t FloatToBuffer(float value, char* buffer) { |
371 | // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all |
372 | // platforms these days. Just in case some system exists where FLT_DIG |
373 | // is significantly larger -- and risks overflowing our buffer -- we have |
374 | // this assert. |
375 | static_assert(FLT_DIG < 10, "FLT_DIG is too big" ); |
376 | |
377 | if (std::isnan(value)) { |
378 | int snprintf_result = snprintf(buffer, kFastToBufferSize, "%snan" , |
379 | std::signbit(value) ? "-" : "" ); |
380 | // Paranoid check to ensure we don't overflow the buffer. |
381 | DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); |
382 | return snprintf_result; |
383 | } |
384 | |
385 | int snprintf_result = |
386 | snprintf(buffer, kFastToBufferSize, "%.*g" , FLT_DIG, value); |
387 | |
388 | // The snprintf should never overflow because the buffer is significantly |
389 | // larger than the precision we asked for. |
390 | DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); |
391 | |
392 | float parsed_value; |
393 | if (!safe_strtof(buffer, &parsed_value) || parsed_value != value) { |
394 | snprintf_result = |
395 | snprintf(buffer, kFastToBufferSize, "%.*g" , FLT_DIG + 3, value); |
396 | |
397 | // Should never overflow; see above. |
398 | DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); |
399 | } |
400 | return snprintf_result; |
401 | } |
402 | |
403 | std::string FpToString(Fprint fp) { |
404 | char buf[17]; |
405 | snprintf(buf, sizeof(buf), "%016llx" , static_cast<long long>(fp)); |
406 | return std::string(buf); |
407 | } |
408 | |
409 | bool StringToFp(const std::string& s, Fprint* fp) { |
410 | char junk; |
411 | uint64_t result; |
412 | if (sscanf(s.c_str(), "%" SCNx64 "%c" , &result, &junk) == 1) { |
413 | *fp = result; |
414 | return true; |
415 | } else { |
416 | return false; |
417 | } |
418 | } |
419 | |
420 | StringPiece Uint64ToHexString(uint64_t v, char* buf) { |
421 | static const char* hexdigits = "0123456789abcdef" ; |
422 | const int num_byte = 16; |
423 | buf[num_byte] = '\0'; |
424 | for (int i = num_byte - 1; i >= 0; i--) { |
425 | buf[i] = hexdigits[v & 0xf]; |
426 | v >>= 4; |
427 | } |
428 | return StringPiece(buf, num_byte); |
429 | } |
430 | |
431 | bool HexStringToUint64(const StringPiece& s, uint64_t* result) { |
432 | uint64_t v = 0; |
433 | if (s.empty()) { |
434 | return false; |
435 | } |
436 | for (size_t i = 0; i < s.size(); i++) { |
437 | char c = s[i]; |
438 | if (c >= '0' && c <= '9') { |
439 | v = (v << 4) + (c - '0'); |
440 | } else if (c >= 'a' && c <= 'f') { |
441 | v = (v << 4) + 10 + (c - 'a'); |
442 | } else if (c >= 'A' && c <= 'F') { |
443 | v = (v << 4) + 10 + (c - 'A'); |
444 | } else { |
445 | return false; |
446 | } |
447 | } |
448 | *result = v; |
449 | return true; |
450 | } |
451 | |
452 | std::string HumanReadableNum(int64_t value) { |
453 | std::string s; |
454 | if (value < 0) { |
455 | s += "-" ; |
456 | value = -value; |
457 | } |
458 | if (value < 1000) { |
459 | Appendf(&s, "%lld" , static_cast<long long>(value)); |
460 | } else if (value >= static_cast<int64_t>(1e15)) { |
461 | // Number bigger than 1E15; use that notation. |
462 | Appendf(&s, "%0.3G" , static_cast<double>(value)); |
463 | } else { |
464 | static const char units[] = "kMBT" ; |
465 | const char* unit = units; |
466 | while (value >= static_cast<int64_t>(1000000)) { |
467 | value /= static_cast<int64_t>(1000); |
468 | ++unit; |
469 | CHECK(unit < units + TF_ARRAYSIZE(units)); |
470 | } |
471 | Appendf(&s, "%.2f%c" , value / 1000.0, *unit); |
472 | } |
473 | return s; |
474 | } |
475 | |
476 | std::string HumanReadableNumBytes(int64_t num_bytes) { |
477 | if (num_bytes == kint64min) { |
478 | // Special case for number with not representable negation. |
479 | return "-8E" ; |
480 | } |
481 | |
482 | const char* neg_str = (num_bytes < 0) ? "-" : "" ; |
483 | if (num_bytes < 0) { |
484 | num_bytes = -num_bytes; |
485 | } |
486 | |
487 | // Special case for bytes. |
488 | if (num_bytes < 1024) { |
489 | // No fractions for bytes. |
490 | char buf[8]; // Longest possible string is '-XXXXB' |
491 | snprintf(buf, sizeof(buf), "%s%lldB" , neg_str, |
492 | static_cast<long long>(num_bytes)); |
493 | return std::string(buf); |
494 | } |
495 | |
496 | static const char units[] = "KMGTPE" ; // int64 only goes up to E. |
497 | const char* unit = units; |
498 | while (num_bytes >= static_cast<int64_t>(1024) * 1024) { |
499 | num_bytes /= 1024; |
500 | ++unit; |
501 | CHECK(unit < units + TF_ARRAYSIZE(units)); |
502 | } |
503 | |
504 | // We use SI prefixes. |
505 | char buf[16]; |
506 | snprintf(buf, sizeof(buf), ((*unit == 'K') ? "%s%.1f%ciB" : "%s%.2f%ciB" ), |
507 | neg_str, num_bytes / 1024.0, *unit); |
508 | return std::string(buf); |
509 | } |
510 | |
511 | std::string HumanReadableElapsedTime(double seconds) { |
512 | std::string human_readable; |
513 | |
514 | if (seconds < 0) { |
515 | human_readable = "-" ; |
516 | seconds = -seconds; |
517 | } |
518 | |
519 | // Start with us and keep going up to years. |
520 | // The comparisons must account for rounding to prevent the format breaking |
521 | // the tested condition and returning, e.g., "1e+03 us" instead of "1 ms". |
522 | const double microseconds = seconds * 1.0e6; |
523 | if (microseconds < 999.5) { |
524 | strings::Appendf(&human_readable, "%0.3g us" , microseconds); |
525 | return human_readable; |
526 | } |
527 | double milliseconds = seconds * 1e3; |
528 | if (milliseconds >= .995 && milliseconds < 1) { |
529 | // Round half to even in Appendf would convert this to 0.999 ms. |
530 | milliseconds = 1.0; |
531 | } |
532 | if (milliseconds < 999.5) { |
533 | strings::Appendf(&human_readable, "%0.3g ms" , milliseconds); |
534 | return human_readable; |
535 | } |
536 | if (seconds < 60.0) { |
537 | strings::Appendf(&human_readable, "%0.3g s" , seconds); |
538 | return human_readable; |
539 | } |
540 | seconds /= 60.0; |
541 | if (seconds < 60.0) { |
542 | strings::Appendf(&human_readable, "%0.3g min" , seconds); |
543 | return human_readable; |
544 | } |
545 | seconds /= 60.0; |
546 | if (seconds < 24.0) { |
547 | strings::Appendf(&human_readable, "%0.3g h" , seconds); |
548 | return human_readable; |
549 | } |
550 | seconds /= 24.0; |
551 | if (seconds < 30.0) { |
552 | strings::Appendf(&human_readable, "%0.3g days" , seconds); |
553 | return human_readable; |
554 | } |
555 | if (seconds < 365.2425) { |
556 | strings::Appendf(&human_readable, "%0.3g months" , seconds / 30.436875); |
557 | return human_readable; |
558 | } |
559 | seconds /= 365.2425; |
560 | strings::Appendf(&human_readable, "%0.3g years" , seconds); |
561 | return human_readable; |
562 | } |
563 | |
564 | } // namespace strings |
565 | } // namespace tsl |
566 | |