1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2Licensed under the Apache License, Version 2.0 (the "License");
3you may not use this file except in compliance with the License.
4You may obtain a copy of the License at
5
6 http://www.apache.org/licenses/LICENSE-2.0
7
8Unless required by applicable law or agreed to in writing, software
9distributed under the License is distributed on an "AS IS" BASIS,
10WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11See the License for the specific language governing permissions and
12limitations under the License.
13==============================================================================*/
14
15#include "tensorflow/tsl/platform/numbers.h"
16
17#include <ctype.h>
18#include <float.h>
19#include <stdio.h>
20#include <stdlib.h>
21
22#include <algorithm>
23#include <cinttypes>
24#include <cmath>
25#include <cstdint>
26#include <locale>
27#include <unordered_map>
28
29#include "double-conversion/double-conversion.h"
30#include "tensorflow/tsl/platform/str_util.h"
31#include "tensorflow/tsl/platform/logging.h"
32#include "tensorflow/tsl/platform/macros.h"
33#include "tensorflow/tsl/platform/stringprintf.h"
34#include "tensorflow/tsl/platform/types.h"
35
36namespace tsl {
37
38namespace {
39
// Returns a pointer to the table that maps the lower-case spelling of each
// special floating-point literal ("inf", "infinity" and "nan", bare or with a
// leading '+' or '-') to the corresponding value of type T.
//
// The table is a function-local static: it is allocated on the first call for
// a given T and is never deleted.
template <typename T>
const std::unordered_map<std::string, T>* GetSpecialNumsSingleton() {
  using SpecialNumMap = std::unordered_map<std::string, T>;
  using Limits = std::numeric_limits<T>;

  // The extra parentheses keep the commas of the initializer list inside a
  // single macro argument.
  static const SpecialNumMap* const kSpecialNums =
      CHECK_NOTNULL((new const SpecialNumMap{
          {"inf", Limits::infinity()},
          {"+inf", Limits::infinity()},
          {"-inf", -Limits::infinity()},
          {"infinity", Limits::infinity()},
          {"+infinity", Limits::infinity()},
          {"-infinity", -Limits::infinity()},
          {"nan", Limits::quiet_NaN()},
          {"+nan", Limits::quiet_NaN()},
          {"-nan", -Limits::quiet_NaN()},
      }));
  return kSpecialNums;
}
56
57template <typename T>
58T locale_independent_strtonum(const char* str, const char** endptr) {
59 auto special_nums = GetSpecialNumsSingleton<T>();
60 std::stringstream s(str);
61
62 // Check if str is one of the special numbers.
63 std::string special_num_str;
64 s >> special_num_str;
65
66 for (size_t i = 0; i < special_num_str.length(); ++i) {
67 special_num_str[i] =
68 std::tolower(special_num_str[i], std::locale::classic());
69 }
70
71 auto entry = special_nums->find(special_num_str);
72 if (entry != special_nums->end()) {
73 *endptr = str + (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str))
74 : s.tellg());
75 return entry->second;
76 } else {
77 // Perhaps it's a hex number
78 if (special_num_str.compare(0, 2, "0x") == 0 ||
79 special_num_str.compare(0, 3, "-0x") == 0) {
80 return strtol(str, const_cast<char**>(endptr), 16);
81 }
82 }
83 // Reset the stream
84 s.str(str);
85 s.clear();
86 // Use the "C" locale
87 s.imbue(std::locale::classic());
88
89 T result;
90 s >> result;
91
92 // Set to result to what strto{f,d} functions would have returned. If the
93 // number was outside the range, the stringstream sets the fail flag, but
94 // returns the +/-max() value, whereas strto{f,d} functions return +/-INF.
95 if (s.fail()) {
96 if (result == std::numeric_limits<T>::max() ||
97 result == std::numeric_limits<T>::infinity()) {
98 result = std::numeric_limits<T>::infinity();
99 s.clear(s.rdstate() & ~std::ios::failbit);
100 } else if (result == -std::numeric_limits<T>::max() ||
101 result == -std::numeric_limits<T>::infinity()) {
102 result = -std::numeric_limits<T>::infinity();
103 s.clear(s.rdstate() & ~std::ios::failbit);
104 }
105 }
106
107 if (endptr) {
108 *endptr =
109 str +
110 (s.fail() ? static_cast<std::iostream::pos_type>(0)
111 : (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str))
112 : s.tellg()));
113 }
114 return result;
115}
116
117static inline const double_conversion::StringToDoubleConverter&
118StringToFloatConverter() {
119 static const double_conversion::StringToDoubleConverter converter(
120 double_conversion::StringToDoubleConverter::ALLOW_LEADING_SPACES |
121 double_conversion::StringToDoubleConverter::ALLOW_HEX |
122 double_conversion::StringToDoubleConverter::ALLOW_TRAILING_SPACES |
123 double_conversion::StringToDoubleConverter::ALLOW_CASE_INSENSIBILITY,
124 0., 0., "inf", "nan");
125 return converter;
126}
127
128} // namespace
129
130namespace strings {
131
132size_t FastInt32ToBufferLeft(int32_t i, char* buffer) {
133 uint32_t u = i;
134 size_t length = 0;
135 if (i < 0) {
136 *buffer++ = '-';
137 ++length;
138 // We need to do the negation in modular (i.e., "unsigned")
139 // arithmetic; MSVC++ apparently warns for plain "-u", so
140 // we write the equivalent expression "0 - u" instead.
141 u = 0 - u;
142 }
143 length += FastUInt32ToBufferLeft(u, buffer);
144 return length;
145}
146
// Writes the decimal representation of `i` to `buffer`, left-aligned and
// NUL-terminated, and returns the number of digits written (the terminator is
// not counted).
size_t FastUInt32ToBufferLeft(uint32_t i, char* buffer) {
  // Count the digits first so they can be stored directly in final order.
  size_t num_digits = 1;
  for (uint32_t rest = i / 10; rest != 0; rest /= 10) {
    ++num_digits;
  }

  buffer[num_digits] = 0;
  // Fill from the least significant digit backwards.
  for (char* out = buffer + num_digits; out != buffer; i /= 10) {
    *--out = static_cast<char>((i % 10) + '0');
  }
  return num_digits;
}
157
158size_t FastInt64ToBufferLeft(int64_t i, char* buffer) {
159 uint64_t u = i;
160 size_t length = 0;
161 if (i < 0) {
162 *buffer++ = '-';
163 ++length;
164 u = 0 - u;
165 }
166 length += FastUInt64ToBufferLeft(u, buffer);
167 return length;
168}
169
// Writes the decimal representation of `i` to `buffer`, left-aligned and
// NUL-terminated, and returns the number of digits written (the terminator is
// not counted).
size_t FastUInt64ToBufferLeft(uint64_t i, char* buffer) {
  // Count the digits first so they can be stored directly in final order.
  size_t num_digits = 1;
  for (uint64_t rest = i / 10; rest != 0; rest /= 10) {
    ++num_digits;
  }

  buffer[num_digits] = 0;
  // Fill from the least significant digit backwards.
  for (char* out = buffer + num_digits; out != buffer; i /= 10) {
    *--out = static_cast<char>((i % 10) + '0');
  }
  return num_digits;
}
180
181static const double kDoublePrecisionCheckMax = DBL_MAX / 1.000000000000001;
182
183size_t DoubleToBuffer(double value, char* buffer) {
184 // DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
185 // platforms these days. Just in case some system exists where DBL_DIG
186 // is significantly larger -- and risks overflowing our buffer -- we have
187 // this assert.
188 static_assert(DBL_DIG < 20, "DBL_DIG is too big");
189
190 if (std::isnan(value)) {
191 int snprintf_result = snprintf(buffer, kFastToBufferSize, "%snan",
192 std::signbit(value) ? "-" : "");
193 // Paranoid check to ensure we don't overflow the buffer.
194 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
195 return snprintf_result;
196 }
197
198 if (std::abs(value) <= kDoublePrecisionCheckMax) {
199 int snprintf_result =
200 snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG, value);
201
202 // The snprintf should never overflow because the buffer is significantly
203 // larger than the precision we asked for.
204 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
205
206 if (locale_independent_strtonum<double>(buffer, nullptr) == value) {
207 // Round-tripping the string to double works; we're done.
208 return snprintf_result;
209 }
210 // else: full precision formatting needed. Fall through.
211 }
212
213 int snprintf_result =
214 snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG + 2, value);
215
216 // Should never overflow; see above.
217 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
218
219 return snprintf_result;
220}
221
222namespace {
223char SafeFirstChar(StringPiece str) {
224 if (str.empty()) return '\0';
225 return str[0];
226}
227void SkipSpaces(StringPiece* str) {
228 while (isspace(SafeFirstChar(*str))) str->remove_prefix(1);
229}
230} // namespace
231
232bool safe_strto64(StringPiece str, int64_t* value) {
233 SkipSpaces(&str);
234
235 int64_t vlimit = kint64max;
236 int sign = 1;
237 if (absl::ConsumePrefix(&str, "-")) {
238 sign = -1;
239 // Different limit for positive and negative integers.
240 vlimit = kint64min;
241 }
242
243 if (!isdigit(SafeFirstChar(str))) return false;
244
245 int64_t result = 0;
246 if (sign == 1) {
247 do {
248 int digit = SafeFirstChar(str) - '0';
249 if ((vlimit - digit) / 10 < result) {
250 return false;
251 }
252 result = result * 10 + digit;
253 str.remove_prefix(1);
254 } while (isdigit(SafeFirstChar(str)));
255 } else {
256 do {
257 int digit = SafeFirstChar(str) - '0';
258 if ((vlimit + digit) / 10 > result) {
259 return false;
260 }
261 result = result * 10 - digit;
262 str.remove_prefix(1);
263 } while (isdigit(SafeFirstChar(str)));
264 }
265
266 SkipSpaces(&str);
267 if (!str.empty()) return false;
268
269 *value = result;
270 return true;
271}
272
273bool safe_strtou64(StringPiece str, uint64_t* value) {
274 SkipSpaces(&str);
275 if (!isdigit(SafeFirstChar(str))) return false;
276
277 uint64_t result = 0;
278 do {
279 int digit = SafeFirstChar(str) - '0';
280 if ((kuint64max - digit) / 10 < result) {
281 return false;
282 }
283 result = result * 10 + digit;
284 str.remove_prefix(1);
285 } while (isdigit(SafeFirstChar(str)));
286
287 SkipSpaces(&str);
288 if (!str.empty()) return false;
289
290 *value = result;
291 return true;
292}
293
294bool safe_strto32(StringPiece str, int32_t* value) {
295 SkipSpaces(&str);
296
297 int64_t vmax = kint32max;
298 int sign = 1;
299 if (absl::ConsumePrefix(&str, "-")) {
300 sign = -1;
301 // Different max for positive and negative integers.
302 ++vmax;
303 }
304
305 if (!isdigit(SafeFirstChar(str))) return false;
306
307 int64_t result = 0;
308 do {
309 result = result * 10 + SafeFirstChar(str) - '0';
310 if (result > vmax) {
311 return false;
312 }
313 str.remove_prefix(1);
314 } while (isdigit(SafeFirstChar(str)));
315
316 SkipSpaces(&str);
317
318 if (!str.empty()) return false;
319
320 *value = static_cast<int32_t>(result * sign);
321 return true;
322}
323
324bool safe_strtou32(StringPiece str, uint32_t* value) {
325 SkipSpaces(&str);
326 if (!isdigit(SafeFirstChar(str))) return false;
327
328 int64_t result = 0;
329 do {
330 result = result * 10 + SafeFirstChar(str) - '0';
331 if (result > kuint32max) {
332 return false;
333 }
334 str.remove_prefix(1);
335 } while (isdigit(SafeFirstChar(str)));
336
337 SkipSpaces(&str);
338 if (!str.empty()) return false;
339
340 *value = static_cast<uint32_t>(result);
341 return true;
342}
343
344bool safe_strtof(StringPiece str, float* value) {
345 int processed_characters_count = -1;
346 auto len = str.size();
347
348 // If string length exceeds buffer size or int max, fail.
349 if (len >= kFastToBufferSize) return false;
350 if (len > std::numeric_limits<int>::max()) return false;
351
352 *value = StringToFloatConverter().StringToFloat(
353 str.data(), static_cast<int>(len), &processed_characters_count);
354 return processed_characters_count > 0;
355}
356
357bool safe_strtod(StringPiece str, double* value) {
358 int processed_characters_count = -1;
359 auto len = str.size();
360
361 // If string length exceeds buffer size or int max, fail.
362 if (len >= kFastToBufferSize) return false;
363 if (len > std::numeric_limits<int>::max()) return false;
364
365 *value = StringToFloatConverter().StringToDouble(
366 str.data(), static_cast<int>(len), &processed_characters_count);
367 return processed_characters_count > 0;
368}
369
370size_t FloatToBuffer(float value, char* buffer) {
371 // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
372 // platforms these days. Just in case some system exists where FLT_DIG
373 // is significantly larger -- and risks overflowing our buffer -- we have
374 // this assert.
375 static_assert(FLT_DIG < 10, "FLT_DIG is too big");
376
377 if (std::isnan(value)) {
378 int snprintf_result = snprintf(buffer, kFastToBufferSize, "%snan",
379 std::signbit(value) ? "-" : "");
380 // Paranoid check to ensure we don't overflow the buffer.
381 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
382 return snprintf_result;
383 }
384
385 int snprintf_result =
386 snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG, value);
387
388 // The snprintf should never overflow because the buffer is significantly
389 // larger than the precision we asked for.
390 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
391
392 float parsed_value;
393 if (!safe_strtof(buffer, &parsed_value) || parsed_value != value) {
394 snprintf_result =
395 snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 3, value);
396
397 // Should never overflow; see above.
398 DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
399 }
400 return snprintf_result;
401}
402
403std::string FpToString(Fprint fp) {
404 char buf[17];
405 snprintf(buf, sizeof(buf), "%016llx", static_cast<long long>(fp));
406 return std::string(buf);
407}
408
409bool StringToFp(const std::string& s, Fprint* fp) {
410 char junk;
411 uint64_t result;
412 if (sscanf(s.c_str(), "%" SCNx64 "%c", &result, &junk) == 1) {
413 *fp = result;
414 return true;
415 } else {
416 return false;
417 }
418}
419
420StringPiece Uint64ToHexString(uint64_t v, char* buf) {
421 static const char* hexdigits = "0123456789abcdef";
422 const int num_byte = 16;
423 buf[num_byte] = '\0';
424 for (int i = num_byte - 1; i >= 0; i--) {
425 buf[i] = hexdigits[v & 0xf];
426 v >>= 4;
427 }
428 return StringPiece(buf, num_byte);
429}
430
431bool HexStringToUint64(const StringPiece& s, uint64_t* result) {
432 uint64_t v = 0;
433 if (s.empty()) {
434 return false;
435 }
436 for (size_t i = 0; i < s.size(); i++) {
437 char c = s[i];
438 if (c >= '0' && c <= '9') {
439 v = (v << 4) + (c - '0');
440 } else if (c >= 'a' && c <= 'f') {
441 v = (v << 4) + 10 + (c - 'a');
442 } else if (c >= 'A' && c <= 'F') {
443 v = (v << 4) + 10 + (c - 'A');
444 } else {
445 return false;
446 }
447 }
448 *result = v;
449 return true;
450}
451
452std::string HumanReadableNum(int64_t value) {
453 std::string s;
454 if (value < 0) {
455 s += "-";
456 value = -value;
457 }
458 if (value < 1000) {
459 Appendf(&s, "%lld", static_cast<long long>(value));
460 } else if (value >= static_cast<int64_t>(1e15)) {
461 // Number bigger than 1E15; use that notation.
462 Appendf(&s, "%0.3G", static_cast<double>(value));
463 } else {
464 static const char units[] = "kMBT";
465 const char* unit = units;
466 while (value >= static_cast<int64_t>(1000000)) {
467 value /= static_cast<int64_t>(1000);
468 ++unit;
469 CHECK(unit < units + TF_ARRAYSIZE(units));
470 }
471 Appendf(&s, "%.2f%c", value / 1000.0, *unit);
472 }
473 return s;
474}
475
476std::string HumanReadableNumBytes(int64_t num_bytes) {
477 if (num_bytes == kint64min) {
478 // Special case for number with not representable negation.
479 return "-8E";
480 }
481
482 const char* neg_str = (num_bytes < 0) ? "-" : "";
483 if (num_bytes < 0) {
484 num_bytes = -num_bytes;
485 }
486
487 // Special case for bytes.
488 if (num_bytes < 1024) {
489 // No fractions for bytes.
490 char buf[8]; // Longest possible string is '-XXXXB'
491 snprintf(buf, sizeof(buf), "%s%lldB", neg_str,
492 static_cast<long long>(num_bytes));
493 return std::string(buf);
494 }
495
496 static const char units[] = "KMGTPE"; // int64 only goes up to E.
497 const char* unit = units;
498 while (num_bytes >= static_cast<int64_t>(1024) * 1024) {
499 num_bytes /= 1024;
500 ++unit;
501 CHECK(unit < units + TF_ARRAYSIZE(units));
502 }
503
504 // We use SI prefixes.
505 char buf[16];
506 snprintf(buf, sizeof(buf), ((*unit == 'K') ? "%s%.1f%ciB" : "%s%.2f%ciB"),
507 neg_str, num_bytes / 1024.0, *unit);
508 return std::string(buf);
509}
510
511std::string HumanReadableElapsedTime(double seconds) {
512 std::string human_readable;
513
514 if (seconds < 0) {
515 human_readable = "-";
516 seconds = -seconds;
517 }
518
519 // Start with us and keep going up to years.
520 // The comparisons must account for rounding to prevent the format breaking
521 // the tested condition and returning, e.g., "1e+03 us" instead of "1 ms".
522 const double microseconds = seconds * 1.0e6;
523 if (microseconds < 999.5) {
524 strings::Appendf(&human_readable, "%0.3g us", microseconds);
525 return human_readable;
526 }
527 double milliseconds = seconds * 1e3;
528 if (milliseconds >= .995 && milliseconds < 1) {
529 // Round half to even in Appendf would convert this to 0.999 ms.
530 milliseconds = 1.0;
531 }
532 if (milliseconds < 999.5) {
533 strings::Appendf(&human_readable, "%0.3g ms", milliseconds);
534 return human_readable;
535 }
536 if (seconds < 60.0) {
537 strings::Appendf(&human_readable, "%0.3g s", seconds);
538 return human_readable;
539 }
540 seconds /= 60.0;
541 if (seconds < 60.0) {
542 strings::Appendf(&human_readable, "%0.3g min", seconds);
543 return human_readable;
544 }
545 seconds /= 60.0;
546 if (seconds < 24.0) {
547 strings::Appendf(&human_readable, "%0.3g h", seconds);
548 return human_readable;
549 }
550 seconds /= 24.0;
551 if (seconds < 30.0) {
552 strings::Appendf(&human_readable, "%0.3g days", seconds);
553 return human_readable;
554 }
555 if (seconds < 365.2425) {
556 strings::Appendf(&human_readable, "%0.3g months", seconds / 30.436875);
557 return human_readable;
558 }
559 seconds /= 365.2425;
560 strings::Appendf(&human_readable, "%0.3g years", seconds);
561 return human_readable;
562}
563
564} // namespace strings
565} // namespace tsl
566