1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file tvm/support/span.h
 * \brief Partial reimplementation of the C++20 std::span.
24 */
25#ifndef TVM_SUPPORT_SPAN_H_
26#define TVM_SUPPORT_SPAN_H_
27
28#include <cstddef>
29#include <iterator>
30#include <type_traits>
31#include <vector>
32
33namespace tvm {
34namespace support {
35
/*!
 * \brief A partial implementation of the C++20 std::span.
 *
 * At the time of writing, TVM must compile against C++17.
 *
 * A Span views the contiguous range [begin, end) of elements of type T
 * through raw pointers; it has no destructor and never frees the elements.
 * Each element is exposed as a value of type W, constructed on access as
 * W(element).
 *
 * \tparam T The element type stored in the viewed memory.
 * \tparam W The type each element is converted to when it is read.
 */
template <class T, class W>
class Span {
 public:
  using value_type = W;
  using const_W = typename std::add_const<W>::type;

  /*!
   * \brief Input iterator over a Span, yielding elements converted to W1.
   * \tparam W1 W for `iterator`, const_W for `const_iterator`.
   */
  template <class W1>
  class iterator_base {
   public:
    using iterator_category = std::input_iterator_tag;
    using value_type = W;
    using difference_type = std::ptrdiff_t;
    // NOTE(review): operator* returns W1 by value, so `pointer`/`reference`
    // do not describe the dereference result exactly; kept for compatibility.
    using pointer = const W*;
    using reference = const W&;

    /*!
     * \brief Create an iterator positioned at `ptr` in a range ending at `end`.
     * \param ptr The current position; must not be past `end`.
     * \param end One past the last element of the underlying range.
     */
    inline iterator_base(T* ptr, T* end) : ptr_{ptr}, end_{end} { CHECK_GE(end, ptr); }

    /*! \brief Return the current element converted to W1 (by value). */
    inline W1 operator*() const { return W1(*ptr_); }

    /*! \brief Advance by one element; an iterator already at the end stays there. */
    inline iterator_base<W1>& operator++() {
      if (ptr_ != end_) ptr_++;
      return *this;
    }

    /*!
     * \brief Post-increment, as required of input iterators.
     * \return A copy of this iterator taken before it was advanced.
     */
    inline iterator_base<W1> operator++(int) {
      iterator_base<W1> previous = *this;
      ++(*this);
      return previous;
    }

    /*! \brief Two iterators are equal when both the position and the range end match. */
    inline bool operator==(const iterator_base<W1>& other) const {
      return ptr_ == other.ptr_ && end_ == other.end_;
    }

    inline bool operator!=(const iterator_base<W1>& other) const { return !(*this == other); }

    /*! \brief Let a mutable iterator be used where a const_iterator is expected. */
    template <class X = W1, typename = std::enable_if_t<!std::is_const<X>::value>>
    inline operator iterator_base<const_W>() const {
      return iterator_base<const_W>(ptr_, end_);
    }

   private:
    T* ptr_;
    T* end_;
  };

  using iterator = iterator_base<W>;
  using const_iterator = iterator_base<const_W>;

  /*!
   * \brief View `num_elements` consecutive elements starting at `begin`.
   * \param begin Pointer to the first element.
   * \param num_elements Number of elements in the view.
   */
  inline Span(T* begin, int num_elements) : begin_{begin}, end_{begin + num_elements} {}

  /*! \brief View the half-open range [begin, end). */
  inline Span(T* begin, T* end) : begin_{begin}, end_{end} {}

  /*! \brief Iterator positioned at the first element. */
  inline iterator begin() const { return iterator(begin_, end_); }

  /*! \brief Iterator positioned one past the last element. */
  inline iterator end() const { return iterator(end_, end_); }

  /*! \brief Number of elements in the view. */
  size_t size() const { return static_cast<size_t>(end_ - begin_); }

  /*!
   * \brief Return element `i` converted to W.
   * \param i Zero-based index; must satisfy 0 <= i < size().
   */
  inline W operator[](int i) const {
    // Validate the index before forming a pointer from it: computing
    // begin_ + i for an out-of-range (including negative) i is undefined.
    ICHECK(0 <= i && i < end_ - begin_) << "Span access out of bounds: " << i;
    return W(begin_[i]);
  }

  /*! \brief Copy every element, converted to W, into a new std::vector. */
  inline operator std::vector<W>() const { return std::vector<W>(begin(), end()); }

 protected:
  T* begin_;
  T* end_;
};
105
106} // namespace support
107} // namespace tvm
108
109#endif // TVM_SUPPORT_SPAN_H_
110