1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/container/base.h
22 * \brief Base utilities for common POD(plain old data) container types.
23 */
24#ifndef TVM_RUNTIME_CONTAINER_BASE_H_
25#define TVM_RUNTIME_CONTAINER_BASE_H_
26
27#include <dmlc/logging.h>
28#include <tvm/runtime/logging.h>
29#include <tvm/runtime/memory.h>
30#include <tvm/runtime/object.h>
31
32#include <algorithm>
33#include <initializer_list>
34#include <utility>
35
36namespace tvm {
37namespace runtime {
38
39/*! \brief String-aware ObjectRef equal functor */
40struct ObjectHash {
41 /*!
42 * \brief Calculate the hash code of an ObjectRef
43 * \param a The given ObjectRef
44 * \return Hash code of a, string hash for strings and pointer address otherwise.
45 */
46 size_t operator()(const ObjectRef& a) const;
47};
48
49/*! \brief String-aware ObjectRef hash functor */
50struct ObjectEqual {
51 /*!
52 * \brief Check if the two ObjectRef are equal
53 * \param a One ObjectRef
54 * \param b The other ObjectRef
55 * \return String equality if both are strings, pointer address equality otherwise.
56 */
57 bool operator()(const ObjectRef& a, const ObjectRef& b) const;
58};
59
60/*!
61 * \brief Base template for classes with array like memory layout.
62 *
63 * It provides general methods to access the memory. The memory
64 * layout is ArrayType + [ElemType]. The alignment of ArrayType
65 * and ElemType is handled by the memory allocator.
66 *
67 * \tparam ArrayType The array header type, contains object specific metadata.
68 * \tparam ElemType The type of objects stored in the array right after
69 * ArrayType.
70 *
71 * \code
72 * // Example usage of the template to define a simple array wrapper
73 * class ArrayObj : public InplaceArrayBase<ArrayObj, Elem> {
74 * public:
75 * // Wrap EmplaceInit to initialize the elements
76 * template <typename Iterator>
77 * void Init(Iterator begin, Iterator end) {
78 * size_t num_elems = std::distance(begin, end);
79 * auto it = begin;
80 * this->size = 0;
81 * for (size_t i = 0; i < num_elems; ++i) {
82 * InplaceArrayBase::EmplaceInit(i, *it++);
83 * this->size++;
84 * }
85 * }
86 * }
87 *
88 * void test_function() {
89 * vector<Elem> fields;
90 * auto ptr = make_inplace_array_object<ArrayObj, Elem>(fields.size());
91 * ptr->Init(fields.begin(), fields.end());
92 *
93 * // Access the 0th element in the array.
94 * assert(ptr->operator[](0) == fields[0]);
95 * }
96 *
97 * \endcode
98 */
99template <typename ArrayType, typename ElemType>
100class InplaceArrayBase {
101 public:
102 /*!
103 * \brief Access element at index
104 * \param idx The index of the element.
105 * \return Const reference to ElemType at the index.
106 */
107 const ElemType& operator[](size_t idx) const {
108 size_t size = Self()->GetSize();
109 ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n";
110 return *(reinterpret_cast<ElemType*>(AddressOf(idx)));
111 }
112
113 /*!
114 * \brief Access element at index
115 * \param idx The index of the element.
116 * \return Reference to ElemType at the index.
117 */
118 ElemType& operator[](size_t idx) {
119 size_t size = Self()->GetSize();
120 ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n";
121 return *(reinterpret_cast<ElemType*>(AddressOf(idx)));
122 }
123
124 /*!
125 * \brief Destroy the Inplace Array Base object
126 */
127 ~InplaceArrayBase() {
128 if (!(std::is_standard_layout<ElemType>::value && std::is_trivial<ElemType>::value)) {
129 size_t size = Self()->GetSize();
130 for (size_t i = 0; i < size; ++i) {
131 ElemType* fp = reinterpret_cast<ElemType*>(AddressOf(i));
132 fp->ElemType::~ElemType();
133 }
134 }
135 }
136
137 protected:
138 /*!
139 * \brief Construct a value in place with the arguments.
140 *
141 * \tparam Args Type parameters of the arguments.
142 * \param idx Index of the element.
143 * \param args Arguments to construct the new value.
144 *
145 * \note Please make sure ArrayType::GetSize returns 0 before first call of
146 * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds.
147 */
148 template <typename... Args>
149 void EmplaceInit(size_t idx, Args&&... args) {
150 void* field_ptr = AddressOf(idx);
151 new (field_ptr) ElemType(std::forward<Args>(args)...);
152 }
153
154 /*!
155 * \brief Return the self object for the array.
156 *
157 * \return Pointer to ArrayType.
158 */
159 inline ArrayType* Self() const {
160 return static_cast<ArrayType*>(const_cast<InplaceArrayBase*>(this));
161 }
162
163 /*!
164 * \brief Return the raw pointer to the element at idx.
165 *
166 * \param idx The index of the element.
167 * \return Raw pointer to the element.
168 */
169 void* AddressOf(size_t idx) const {
170 static_assert(
171 alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0,
172 "The size and alignment of ArrayType should respect "
173 "ElemType's alignment.");
174
175 size_t kDataStart = sizeof(ArrayType);
176 ArrayType* self = Self();
177 char* data_start = reinterpret_cast<char*>(self) + kDataStart;
178 return data_start + idx * sizeof(ElemType);
179 }
180};
181
/*!
 * \brief Forward-style iterator adapter that wraps an iterator of type TIter
 *  and converts every dereferenced element with Converter::convert.
 * \tparam Converter Provides a `ResultType` typedef and a static `convert`
 *  function that is applied to `*it` of the wrapped iterator.
 * \tparam TIter The wrapped iterator type; its difference type and category
 *  are re-exported unchanged.
 */
template <typename Converter, typename TIter>
class IterAdapter {
 public:
  using difference_type = typename std::iterator_traits<TIter>::difference_type;
  using value_type = typename Converter::ResultType;
  using pointer = typename Converter::ResultType*;
  // NOTE(review): operator* returns by value, so this alias is only nominal.
  using reference = typename Converter::ResultType&;
  using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

  /*! \brief Wrap the given iterator. */
  explicit IterAdapter(TIter iter) : cursor_(iter) {}

  /*! \brief Step forward and return the adapter itself. */
  IterAdapter& operator++() {
    ++cursor_;
    return *this;
  }

  /*! \brief Step forward, handing back the position before the step. */
  IterAdapter operator++(int) {
    IterAdapter before(*this);
    ++(*this);
    return before;
  }

  /*! \brief Step backward and return the adapter itself. */
  IterAdapter& operator--() {
    --cursor_;
    return *this;
  }

  /*! \brief Step backward, handing back the position before the step. */
  IterAdapter operator--(int) {
    IterAdapter before(*this);
    --(*this);
    return before;
  }

  /*! \brief A new adapter positioned `offset` steps further along. */
  IterAdapter operator+(difference_type offset) const {
    TIter shifted = cursor_ + offset;
    return IterAdapter(shifted);
  }

  /*! \brief A new adapter positioned `offset` steps back. */
  IterAdapter operator-(difference_type offset) const {
    TIter shifted = cursor_ - offset;
    return IterAdapter(shifted);
  }

  /*!
   * \brief Distance between two adapters; only available when the wrapped
   *  iterator is random access.
   */
  template <typename T = IterAdapter>
  inline typename std::enable_if<
      std::is_same<iterator_category, std::random_access_iterator_tag>::value,
      typename T::difference_type>::type
  operator-(const IterAdapter& rhs) const {
    return cursor_ - rhs.cursor_;
  }

  /*! \brief Adapters compare equal when their wrapped iterators do. */
  bool operator==(IterAdapter other) const { return cursor_ == other.cursor_; }
  bool operator!=(IterAdapter other) const { return !(cursor_ == other.cursor_); }

  /*! \brief Convert the element under the wrapped iterator. */
  const value_type operator*() const { return Converter::convert(*cursor_); }

 private:
  /*! \brief The wrapped iterator. */
  TIter cursor_;
};
234
/*!
 * \brief Iterator adapter that walks TIter backwards and converts every
 *  dereferenced element with Converter::convert.
 *
 * Incrementing the adapter decrements the wrapped iterator and vice versa;
 * `+ n` / `- n` likewise move the wrapped iterator by `-n` / `+n`.
 * Unlike std::reverse_iterator, dereferencing reads `*iter` itself (not
 * `*(iter - 1)`), so the adapter must be constructed from an iterator that
 * points at the first element to visit.
 *
 * \tparam Converter Provides a `ResultType` typedef and a static `convert`
 *  function applied to `*iter` of the wrapped iterator.
 * \tparam TIter The wrapped iterator type.
 */
template <typename Converter, typename TIter>
class ReverseIterAdapter {
 public:
  using difference_type = typename std::iterator_traits<TIter>::difference_type;
  using value_type = typename Converter::ResultType;
  using pointer = typename Converter::ResultType*;
  using reference = typename Converter::ResultType&;  // NOLINT(*)
  using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

  explicit ReverseIterAdapter(TIter iter) : iter_(iter) {}
  ReverseIterAdapter& operator++() {
    --iter_;
    return *this;
  }
  ReverseIterAdapter& operator--() {
    ++iter_;
    return *this;
  }
  ReverseIterAdapter operator++(int) {
    ReverseIterAdapter copy = *this;
    --iter_;
    return copy;
  }
  ReverseIterAdapter operator--(int) {
    ReverseIterAdapter copy = *this;
    ++iter_;
    return copy;
  }
  /*! \brief Advance `offset` steps in reverse order (wrapped iterator moves back). */
  ReverseIterAdapter operator+(difference_type offset) const {
    return ReverseIterAdapter(iter_ - offset);
  }

  /*!
   * \brief Retreat `offset` steps in reverse order (wrapped iterator moves
   *  forward). Mirrors IterAdapter::operator-(difference_type).
   */
  ReverseIterAdapter operator-(difference_type offset) const {
    return ReverseIterAdapter(iter_ + offset);
  }

  /*!
   * \brief Distance in reverse order between two adapters; only available
   *  when the wrapped iterator is random access. The operands of the
   *  underlying subtraction are swapped because traversal is reversed.
   */
  template <typename T = ReverseIterAdapter>
  typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
                          typename T::difference_type>::type inline
  operator-(const ReverseIterAdapter& rhs) const {
    return rhs.iter_ - iter_;
  }

  bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; }
  bool operator!=(ReverseIterAdapter other) const { return !(*this == other); }
  /*! \brief Convert the element under the wrapped iterator. */
  const value_type operator*() const { return Converter::convert(*iter_); }

 private:
  /*! \brief The wrapped iterator; points at the current element. */
  TIter iter_;
};
286
287} // namespace runtime
288
289// expose the functions to the root namespace.
290using runtime::Downcast;
291using runtime::IterAdapter;
292using runtime::make_object;
293using runtime::Object;
294using runtime::ObjectEqual;
295using runtime::ObjectHash;
296using runtime::ObjectPtr;
297using runtime::ObjectPtrEqual;
298using runtime::ObjectPtrHash;
299using runtime::ObjectRef;
300} // namespace tvm
301
302#endif // TVM_RUNTIME_CONTAINER_BASE_H_
303