1#pragma once
2
3#include <c10/macros/Macros.h>
4#include <c10/util/C++17.h>
5#include <c10/util/reverse_iterator.h>
6#include <algorithm>
7#include <cstring>
8#include <limits>
9#include <stdexcept>
10#include <string>
11
// Detect which standard-library string_view, if any, is available. The result
// is used by operator<< and std::hash below to delegate to the standard
// library. After this block exactly one of the two macros expands to 1, or
// both expand to 0 (via the #ifndef fallbacks at the end of the block).
#if __cpp_lib_string_view
// C++17 std::string_view is available.
#include <string_view>
#define C10_HAS_STD_STRING_VIEW() 1
#define C10_HAS_STD_EXPERIMENTAL_STRING_VIEW() 0
#elif defined(__has_include)
#if __has_include(<experimental/string_view>)
// libc++ 7.0 has experimental/string_view but it's just a #error
#if !defined(_LIBCPP_VERSION) || (_LIBCPP_VERSION < 7000)
#include <experimental/string_view>
#endif
// Only use the experimental header if it actually advertised its feature-test
// macro (it is not included at all on libc++ >= 7.0, see above).
#if __cpp_lib_experimental_string_view
#define C10_HAS_STD_STRING_VIEW() 0
#define C10_HAS_STD_EXPERIMENTAL_STRING_VIEW() 1
#endif
#endif
#endif

// Fallbacks: neither standard string_view flavor was detected.
#ifndef C10_HAS_STD_STRING_VIEW
#define C10_HAS_STD_STRING_VIEW() 0
#endif
#ifndef C10_HAS_STD_EXPERIMENTAL_STRING_VIEW
#define C10_HAS_STD_EXPERIMENTAL_STRING_VIEW() 0
#endif
35
// Save clang's diagnostic state and silence -Wdeprecated for the rest of this
// header; paired with C10_CLANG_DIAGNOSTIC_POP() at the end of the file.
// NOTE(review): presumably this targets the out-of-class definition of the
// constexpr static member `npos`, which C++17 treats as a deprecated
// redundant redeclaration — confirm.
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wdeprecated")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated")
#endif
40
41namespace c10 {
42
43/**
44 * Reimplementation of std::string_view for C++11.
45 * Implemented following the interface definition in
46 * https://en.cppreference.com/w/cpp/string/basic_string_view
47 * See there for the API documentation.
48 *
49 * Difference: We don't have a Traits template parameter because
50 * std::char_traits isn't constexpr and we'd have to reimplement
51 * std::char_traits if we wanted to use it with our constexpr basic_string_view.
52 */
53template <class CharT>
54class basic_string_view final {
55 public:
56 using value_type = CharT;
57 using pointer = CharT*;
58 using const_pointer = const CharT*;
59 using reference = CharT&;
60 using const_reference = const CharT&;
61 using const_iterator = const CharT*;
62 using iterator = const_iterator;
63 using const_reverse_iterator = c10::reverse_iterator<const_iterator>;
64 using reverse_iterator = const_reverse_iterator;
65 using size_type = std::size_t;
66 using difference_type = std::ptrdiff_t;
67
68 static constexpr size_type npos = size_type(-1);
69
70 constexpr basic_string_view() noexcept : begin_(nullptr), size_(0) {}
71
72 explicit constexpr basic_string_view(const_pointer str, size_type count)
73 : begin_(str), size_(count) {}
74
75 /* implicit */ constexpr basic_string_view(const_pointer str)
76 : basic_string_view(str, strlen_(str)) {}
77
78 /* implicit */ basic_string_view(const ::std::basic_string<CharT>& str)
79 : basic_string_view(str.data(), str.size()) {}
80
81 constexpr basic_string_view(const basic_string_view&) noexcept = default;
82
83 constexpr basic_string_view& operator=(
84 const basic_string_view& rhs) noexcept {
85 begin_ = rhs.begin_;
86 size_ = rhs.size_;
87 return *this;
88 }
89
90 explicit operator ::std::basic_string<CharT>() const {
91 return ::std::basic_string<CharT>(data(), size());
92 }
93
94 constexpr const_iterator begin() const noexcept {
95 return cbegin();
96 }
97
98 constexpr const_iterator cbegin() const noexcept {
99 return begin_;
100 }
101
102 constexpr const_iterator end() const noexcept {
103 return cend();
104 }
105
106 constexpr const_iterator cend() const noexcept {
107 return begin_ + size_;
108 }
109
110 constexpr const_reverse_iterator rbegin() const noexcept {
111 return crbegin();
112 }
113
114 constexpr const_reverse_iterator crbegin() const noexcept {
115 return const_reverse_iterator(this->end());
116 }
117
118 constexpr const_reverse_iterator rend() const noexcept {
119 return crend();
120 }
121
122 constexpr const_reverse_iterator crend() const noexcept {
123 return const_reverse_iterator(this->begin());
124 }
125
126 friend constexpr const_iterator begin(basic_string_view sv) noexcept {
127 return sv.begin();
128 }
129
130 friend constexpr const_iterator end(basic_string_view sv) noexcept {
131 return sv.end();
132 }
133
134 constexpr const_reference operator[](size_type pos) const {
135 // TODO: split out
136 return at_(pos);
137 }
138
139 constexpr const_reference at(size_type pos) const {
140#if !defined( \
141 __CUDA_ARCH__) // CUDA doesn't like std::out_of_range in device code
142 return C10_UNLIKELY(pos >= size_)
143 ? (throw std::out_of_range(
144 "string_view::operator[] or string_view::at() out of range. Index: " +
145 c10::guts::to_string(pos) +
146 ", size: " + c10::guts::to_string(size())),
147 at_(0))
148 : at_(pos);
149#else
150 return at_(pos);
151#endif
152 }
153
154 constexpr const_reference front() const {
155 return *begin_;
156 }
157
158 constexpr const_reference back() const {
159 return *(begin_ + size_ - 1);
160 }
161
162 constexpr const_pointer data() const noexcept {
163 return begin_;
164 }
165
166 constexpr size_type size() const noexcept {
167 return size_;
168 }
169
170 constexpr size_type length() const noexcept {
171 return size();
172 }
173
174 constexpr size_type max_size() const noexcept {
175 return std::numeric_limits<difference_type>::max();
176 }
177
178 C10_NODISCARD constexpr bool empty() const noexcept {
179 return size() == 0;
180 }
181
182 constexpr void remove_prefix(size_type n) {
183 if (n > size()) {
184 throw std::out_of_range(
185 "basic_string_view::remove_prefix: out of range. PrefixLength: " +
186 c10::guts::to_string(n) + ", size: " + c10::guts::to_string(size()));
187 }
188 begin_ += n;
189 size_ -= n;
190 }
191
192 constexpr void remove_suffix(size_type n) {
193 if (n > size()) {
194 throw std::out_of_range(
195 "basic_string_view::remove_suffix: out of range. SuffixLength: " +
196 c10::guts::to_string(n) + ", size: " + c10::guts::to_string(size()));
197 }
198 size_ -= n;
199 }
200
201 constexpr void swap(basic_string_view& sv) noexcept {
202 auto tmp = *this;
203 *this = sv;
204 sv = tmp;
205 }
206
207 size_type copy(pointer dest, size_type count, size_type pos = 0) const {
208 if (pos > size_) {
209 throw std::out_of_range(
210 "basic_string_view::copy: out of range. Index: " +
211 c10::guts::to_string(pos) +
212 ", size: " + c10::guts::to_string(size()));
213 }
214 size_type copy_length = guts::min(count, size_ - pos);
215 for (auto iter = begin() + pos, end = iter + copy_length; iter != end;) {
216 *(dest++) = *(iter++);
217 }
218 return copy_length;
219 }
220
221 constexpr basic_string_view substr(size_type pos = 0, size_type count = npos)
222 const {
223#if !defined( \
224 __CUDA_ARCH__) // CUDA doesn't like std::out_of_range in device code
225 return (pos > size_)
226 ? (throw std::out_of_range(
227 "basic_string_view::substr parameter out of bounds. Index: " +
228 c10::guts::to_string(pos) +
229 ", size: " + c10::guts::to_string(size())),
230 substr_())
231 : substr_(pos, count);
232#else
233 return substr_(pos, count);
234#endif
235 }
236
237 constexpr int compare(basic_string_view rhs) const noexcept {
238#if __cpp_constexpr >= 201304
239 // if we are in C++14, write it iteratively. This is faster.
240 for (size_t i = 0, end = guts::min(size(), rhs.size()); i < end; ++i) {
241 if (at_(i) < rhs.at_(i)) {
242 return -1;
243 } else if (at_(i) > rhs.at_(i)) {
244 return 1;
245 }
246 }
247 if (size() < rhs.size()) {
248 return -1;
249 } else if (size() > rhs.size()) {
250 return 1;
251 }
252 return 0;
253#else
254 // if we are in C++11, we need to do it recursively because of constexpr
255 // restrictions.
256 return (size() == 0 && rhs.size() == 0) ? 0
257 : (size() == 0) ? -1
258 : (rhs.size() == 0) ? 1
259 : (front() < rhs.front()) ? -1
260 : (front() > rhs.front()) ? 1
261 : substr_(1).compare(rhs.substr_(1));
262#endif
263 }
264
265 constexpr int compare(size_type pos1, size_type count1, basic_string_view v)
266 const {
267 return substr(pos1, count1).compare(v);
268 }
269
270 constexpr int compare(
271 size_type pos1,
272 size_type count1,
273 basic_string_view v,
274 size_type pos2,
275 size_type count2) const {
276 return substr(pos1, count1).compare(v.substr(pos2, count2));
277 }
278
279 constexpr int compare(const_pointer s) const {
280 return compare(basic_string_view(s));
281 }
282
283 constexpr int compare(size_type pos1, size_type count1, const_pointer s)
284 const {
285 return substr(pos1, count1).compare(basic_string_view(s));
286 }
287
288 constexpr int compare(
289 size_type pos1,
290 size_type count1,
291 const_pointer s,
292 size_type count2) const {
293 return substr(pos1, count1).compare(basic_string_view(s, count2));
294 }
295
296 friend constexpr bool operator==(
297 basic_string_view lhs,
298 basic_string_view rhs) noexcept {
299 return lhs.equals_(rhs);
300 }
301
302 friend constexpr bool operator!=(
303 basic_string_view lhs,
304 basic_string_view rhs) noexcept {
305 return !(lhs == rhs);
306 }
307
308 friend constexpr bool operator<(
309 basic_string_view lhs,
310 basic_string_view rhs) noexcept {
311 return lhs.compare(rhs) < 0;
312 }
313
314 friend constexpr bool operator>=(
315 basic_string_view lhs,
316 basic_string_view rhs) noexcept {
317 return !(lhs < rhs);
318 }
319
320 friend constexpr bool operator>(
321 basic_string_view lhs,
322 basic_string_view rhs) noexcept {
323 return rhs < lhs;
324 }
325
326 friend constexpr bool operator<=(
327 basic_string_view lhs,
328 basic_string_view rhs) noexcept {
329 return !(lhs > rhs);
330 }
331
332 constexpr bool starts_with(basic_string_view prefix) const noexcept {
333 return (prefix.size() > size()) ? false
334 : prefix.equals_(substr_(0, prefix.size()));
335 }
336
337 constexpr bool starts_with(CharT prefix) const noexcept {
338 return !empty() && prefix == front();
339 }
340
341 constexpr bool starts_with(const_pointer prefix) const {
342 return starts_with(basic_string_view(prefix));
343 }
344
345 constexpr bool ends_with(basic_string_view suffix) const noexcept {
346 return (suffix.size() > size())
347 ? false
348 : suffix.equals_(substr_(size() - suffix.size(), suffix.size()));
349 }
350
351 constexpr bool ends_with(CharT suffix) const noexcept {
352 return !empty() && suffix == back();
353 }
354
355 constexpr bool ends_with(const_pointer suffix) const {
356 return ends_with(basic_string_view(suffix));
357 }
358
359 constexpr size_type find(basic_string_view v, size_type pos = 0)
360 const noexcept {
361#if __cpp_constexpr >= 201304
362 // if we are in C++14, write it iteratively. This is faster.
363 if (v.size() == 0) {
364 return pos <= size() ? pos : npos;
365 }
366
367 if (pos + v.size() <= size()) {
368 for (size_type cur = pos, end = size() - v.size(); cur <= end; ++cur) {
369 if (v.at_(0) == at_(cur) &&
370 v.substr_(1).equals_(substr_(cur + 1, v.size() - 1))) {
371 return cur;
372 }
373 }
374 }
375 return npos;
376#else
377 // if we are in C++11, we need to do it recursively because of constexpr
378 // restrictions.
379 return (v.size() == 0) ? (pos <= size() ? pos : npos)
380 : (pos + v.size() > size()) ? npos
381 : (v.at_(0) == at_(pos) &&
382 v.substr_(1).equals_(substr_(pos + 1, v.size() - 1)))
383 ? pos
384 : find(v, pos + 1);
385#endif
386 }
387
388 constexpr size_type find(CharT ch, size_type pos = 0) const noexcept {
389 return find_first_if_(pos, charIsEqual_{ch});
390 }
391
392 constexpr size_type find(const_pointer s, size_type pos, size_type count)
393 const {
394 return find(basic_string_view(s, count), pos);
395 }
396
397 constexpr size_type find(const_pointer s, size_type pos = 0) const {
398 return find(basic_string_view(s), pos);
399 }
400
401 constexpr size_type rfind(basic_string_view v, size_type pos = npos)
402 const noexcept {
403#if __cpp_constexpr >= 201304
404 // if we are in C++14, write it iteratively. This is faster.
405 if (v.size() == 0) {
406 return pos <= size() ? pos : size();
407 }
408
409 if (v.size() <= size()) {
410 pos = guts::min(size() - v.size(), pos);
411 do {
412 if (v.at_(0) == at_(pos) &&
413 v.substr_(1).equals_(substr_(pos + 1, v.size() - 1))) {
414 return pos;
415 }
416 } while (pos-- > 0);
417 }
418 return npos;
419#else
420 // if we are in C++11, we need to do it recursively because of constexpr
421 // restrictions.
422 return (v.size() == 0) ? (pos <= size() ? pos : size())
423 : (v.size() > size()) ? npos
424 : (size() - v.size() < pos) ? rfind(v, size() - v.size())
425 : (v.at_(0) == at_(pos) &&
426 v.substr_(1).equals_(substr_(pos + 1, v.size() - 1)))
427 ? pos
428 : (pos == 0) ? npos
429 : rfind(v, pos - 1);
430#endif
431 }
432
433 constexpr size_type rfind(CharT ch, size_type pos = npos) const noexcept {
434 return find_last_if_(pos, charIsEqual_{ch});
435 }
436
437 constexpr size_type rfind(const_pointer s, size_type pos, size_type count)
438 const {
439 return rfind(basic_string_view(s, count), pos);
440 }
441
442 constexpr size_type rfind(const_pointer s, size_type pos = npos) const {
443 return rfind(basic_string_view(s), pos);
444 }
445
446 constexpr size_type find_first_of(basic_string_view v, size_type pos = 0)
447 const noexcept {
448 return find_first_if_(pos, stringViewContainsChar_{v});
449 }
450
451 constexpr size_type find_first_of(CharT ch, size_type pos = 0)
452 const noexcept {
453 return find_first_if_(pos, charIsEqual_{ch});
454 }
455
456 constexpr size_type find_first_of(
457 const_pointer s,
458 size_type pos,
459 size_type count) const {
460 return find_first_of(basic_string_view(s, count), pos);
461 }
462
463 constexpr size_type find_first_of(const_pointer s, size_type pos = 0) const {
464 return find_first_of(basic_string_view(s), pos);
465 }
466
467 constexpr size_type find_last_of(basic_string_view v, size_type pos = npos)
468 const noexcept {
469 return find_last_if_(pos, stringViewContainsChar_{v});
470 }
471
472 constexpr size_type find_last_of(CharT ch, size_type pos = npos)
473 const noexcept {
474 return find_last_if_(pos, charIsEqual_{ch});
475 }
476
477 constexpr size_type find_last_of(
478 const_pointer s,
479 size_type pos,
480 size_type count) const {
481 return find_last_of(basic_string_view(s, count), pos);
482 }
483
484 constexpr size_type find_last_of(const_pointer s, size_type pos = npos)
485 const {
486 return find_last_of(basic_string_view(s), pos);
487 }
488
489 constexpr size_type find_first_not_of(basic_string_view v, size_type pos = 0)
490 const noexcept {
491 return find_first_if_(pos, stringViewDoesNotContainChar_{v});
492 }
493
494 constexpr size_type find_first_not_of(CharT ch, size_type pos = 0)
495 const noexcept {
496 return find_first_if_(pos, charIsNotEqual_{ch});
497 }
498
499 constexpr size_type find_first_not_of(
500 const_pointer s,
501 size_type pos,
502 size_type count) const {
503 return find_first_not_of(basic_string_view(s, count), pos);
504 }
505
506 constexpr size_type find_first_not_of(const_pointer s, size_type pos = 0)
507 const {
508 return find_first_not_of(basic_string_view(s), pos);
509 }
510
511 constexpr size_type find_last_not_of(
512 basic_string_view v,
513 size_type pos = npos) const noexcept {
514 return find_last_if_(pos, stringViewDoesNotContainChar_{v});
515 }
516
517 constexpr size_type find_last_not_of(CharT ch, size_type pos = npos)
518 const noexcept {
519 return find_last_if_(pos, charIsNotEqual_{ch});
520 }
521
522 constexpr size_type find_last_not_of(
523 const_pointer s,
524 size_type pos,
525 size_type count) const {
526 return find_last_not_of(basic_string_view(s, count), pos);
527 }
528
529 constexpr size_type find_last_not_of(const_pointer s, size_type pos = npos)
530 const {
531 return find_last_not_of(basic_string_view(s), pos);
532 }
533
534 private:
535 static constexpr size_type strlen_(const_pointer str) noexcept {
536#if __cpp_constexpr >= 201304
537 // if we are in C++14, write it iteratively. This is faster.
538 const_pointer current = str;
539 while (*current != '\0') {
540 ++current;
541 }
542 return current - str;
543#else
544 // if we are in C++11, we need to do it recursively because of constexpr
545 // restrictions.
546 return (*str == '\0') ? 0 : 1 + strlen_(str + 1);
547#endif
548 }
549
550 constexpr const_reference at_(size_type pos) const noexcept {
551 return *(begin_ + pos);
552 }
553
554 constexpr basic_string_view substr_(size_type pos = 0, size_type count = npos)
555 const {
556 return basic_string_view{begin_ + pos, guts::min(count, size() - pos)};
557 }
558
559 template <class Condition>
560 constexpr size_type find_first_if_(size_type pos, Condition&& condition)
561 const noexcept {
562#if __cpp_constexpr >= 201304
563 // if we are in C++14, write it iteratively. This is faster.
564 if (pos + 1 <= size()) {
565 for (size_type cur = pos; cur < size(); ++cur) {
566 if (condition(at_(cur))) {
567 return cur;
568 }
569 }
570 }
571 return npos;
572#else
573 // if we are in C++11, we need to do it recursively because of constexpr
574 // restrictions.
575 return (pos + 1 > size()) ? npos
576 : condition(at_(pos))
577 ? pos
578 : find_first_if_(pos + 1, std::forward<Condition>(condition));
579#endif
580 }
581
582 template <class Condition>
583 constexpr size_type find_last_if_(size_type pos, Condition&& condition)
584 const noexcept {
585#if __cpp_constexpr >= 201304
586 // if we are in C++14, write it iteratively. This is faster.
587 if (size() > 0) {
588 pos = guts::min(size() - 1, pos);
589 do {
590 if (condition(at_(pos))) {
591 return pos;
592 }
593 } while (pos-- > 0);
594 }
595 return npos;
596#else
597 // if we are in C++11, we need to do it recursively because of constexpr
598 // restrictions.
599 return (size() == 0) ? npos
600 : (pos >= size())
601 ? find_last_if_(size() - 1, std::forward<Condition>(condition))
602 : condition(at_(pos)) ? pos
603 : (pos == 0)
604 ? npos
605 : find_last_if_(pos - 1, std::forward<Condition>(condition));
606#endif
607 }
608
609 constexpr bool equals_(basic_string_view rhs) const {
610 // We don't use string_view::compare() here but implement it manually
611 // because only looking at equality allows for more optimized code.
612#if defined(__GNUC__) && !defined(__CUDACC__)
613 return size() == rhs.size() &&
614 0 == __builtin_memcmp(data(), rhs.data(), size());
615#elif __cpp_constexpr >= 201304
616 // if we are in C++14, write it iteratively. This is faster than the
617 // recursive C++11 implementation below.
618 if (size() != rhs.size()) {
619 return false;
620 }
621 // Yes, memcmp would be laster than this loop, but memcmp isn't constexpr
622 // and I didn't feel like implementing a constexpr memcmp variant.
623 // TODO At some point this should probably be done, including tricks
624 // like comparing one machine word instead of a byte per iteration.
625 for (typename basic_string_view<CharT>::size_type pos = 0; pos < size();
626 ++pos) {
627 if (at_(pos) != rhs.at_(pos)) {
628 return false;
629 }
630 }
631 return true;
632#else
633 // if we are in C++11, we need to do it recursively because of constexpr
634 // restrictions.
635 return (size() != rhs.size()) ? false
636 : (size() == 0) ? true
637 : (front() != rhs.front()) ? false
638 : (substr_(1).equals_(rhs.substr_(1)));
639#endif
640 }
641
642 struct charIsEqual_ final {
643 CharT expected;
644 constexpr bool operator()(CharT actual) const noexcept {
645 return expected == actual;
646 }
647 };
648
649 struct charIsNotEqual_ final {
650 CharT expected;
651 constexpr bool operator()(CharT actual) const noexcept {
652 return expected != actual;
653 }
654 };
655
656 struct stringViewContainsChar_ final {
657 basic_string_view expected;
658 constexpr bool operator()(CharT ch) const noexcept {
659 return npos != expected.find(ch);
660 }
661 };
662
663 struct stringViewDoesNotContainChar_ final {
664 basic_string_view expected;
665 constexpr bool operator()(CharT ch) const noexcept {
666 return npos == expected.find(ch);
667 }
668 };
669
670 const_pointer begin_;
671 size_type size_{};
672};
673
674template <class CharT>
675const typename basic_string_view<CharT>::size_type
676 basic_string_view<CharT>::npos;
677
678template <class CharT>
679inline std::basic_ostream<CharT>& operator<<(
680 std::basic_ostream<CharT>& stream,
681 basic_string_view<CharT> sv) {
682 // The rules for operator<< are quite complex, so lets defer to the
683 // STL implementation. The std::string fallback might be a bit
684 // slower, but is better than getting it wrong.
685
686#if C10_HAS_STD_STRING_VIEW()
687 using std_string_type = ::std::basic_string_view<CharT>;
688#elif C10_HAS_STD_EXPERIMENTAL_STRING_VIEW()
689 using std_string_type = ::std::experimental::basic_string_view<CharT>;
690#else
691 using std_string_type = ::std::basic_string<CharT>;
692#endif
693 return stream << std_string_type(sv.data(), sv.size());
694}
695
696template <class CharT>
697constexpr inline void swap(
698 basic_string_view<CharT>& lhs,
699 basic_string_view<CharT>& rhs) {
700 lhs.swap(rhs);
701}
702
703using string_view = basic_string_view<char>;
704
705} // namespace c10
706
707namespace std {
708template <class CharT>
709struct hash<::c10::basic_string_view<CharT>> {
710 size_t operator()(::c10::basic_string_view<CharT> x) const {
711 // The standard says that std""string_view hashing must do the same as
712 // std::string hashing but leaves the details of std::string hashing
713 // up to the implementer. So, to be conformant, we need to re-use and
714 // existing STL type's hash function. The std::string fallback is probably
715 // slow but the only way to be conformant.
716
717#if C10_HAS_STD_STRING_VIEW()
718 using std_string_type = ::std::basic_string_view<CharT>;
719#elif C10_HAS_STD_EXPERIMENTAL_STRING_VIEW()
720 using std_string_type = ::std::experimental::basic_string_view<CharT>;
721#else
722 using std_string_type = ::std::basic_string<CharT>;
723#endif
724 return ::std::hash<std_string_type>{}(std_string_type(x.data(), x.size()));
725 }
726};
727} // namespace std
728
// Pairs with the C10_CLANG_DIAGNOSTIC_PUSH() near the top of this header.
C10_CLANG_DIAGNOSTIC_POP()
730