1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/runtime/container/base.h |
22 | * \brief Base utilities for common POD(plain old data) container types. |
23 | */ |
24 | #ifndef TVM_RUNTIME_CONTAINER_BASE_H_ |
25 | #define TVM_RUNTIME_CONTAINER_BASE_H_ |
26 | |
27 | #include <dmlc/logging.h> |
28 | #include <tvm/runtime/logging.h> |
29 | #include <tvm/runtime/memory.h> |
30 | #include <tvm/runtime/object.h> |
31 | |
32 | #include <algorithm> |
33 | #include <initializer_list> |
34 | #include <utility> |
35 | |
36 | namespace tvm { |
37 | namespace runtime { |
38 | |
/*! \brief String-aware ObjectRef hash functor */
struct ObjectHash {
  /*!
   * \brief Calculate the hash code of an ObjectRef
   * \param a The given ObjectRef
   * \return Hash code of a, string hash for strings and pointer address otherwise.
   * \note Only the declaration lives in this header; the definition is provided elsewhere.
   */
  size_t operator()(const ObjectRef& a) const;
};
48 | |
/*! \brief String-aware ObjectRef equal functor */
struct ObjectEqual {
  /*!
   * \brief Check if the two ObjectRef are equal
   * \param a One ObjectRef
   * \param b The other ObjectRef
   * \return String equality if both are strings, pointer address equality otherwise.
   * \note Only the declaration lives in this header; the definition is provided elsewhere.
   */
  bool operator()(const ObjectRef& a, const ObjectRef& b) const;
};
59 | |
60 | /*! |
61 | * \brief Base template for classes with array like memory layout. |
62 | * |
63 | * It provides general methods to access the memory. The memory |
64 | * layout is ArrayType + [ElemType]. The alignment of ArrayType |
65 | * and ElemType is handled by the memory allocator. |
66 | * |
67 | * \tparam ArrayType The array header type, contains object specific metadata. |
68 | * \tparam ElemType The type of objects stored in the array right after |
69 | * ArrayType. |
70 | * |
71 | * \code |
72 | * // Example usage of the template to define a simple array wrapper |
73 | * class ArrayObj : public InplaceArrayBase<ArrayObj, Elem> { |
74 | * public: |
75 | * // Wrap EmplaceInit to initialize the elements |
76 | * template <typename Iterator> |
77 | * void Init(Iterator begin, Iterator end) { |
78 | * size_t num_elems = std::distance(begin, end); |
79 | * auto it = begin; |
80 | * this->size = 0; |
81 | * for (size_t i = 0; i < num_elems; ++i) { |
82 | * InplaceArrayBase::EmplaceInit(i, *it++); |
83 | * this->size++; |
84 | * } |
85 | * } |
86 | * } |
87 | * |
88 | * void test_function() { |
89 | * vector<Elem> fields; |
90 | * auto ptr = make_inplace_array_object<ArrayObj, Elem>(fields.size()); |
91 | * ptr->Init(fields.begin(), fields.end()); |
92 | * |
93 | * // Access the 0th element in the array. |
94 | * assert(ptr->operator[](0) == fields[0]); |
95 | * } |
96 | * |
97 | * \endcode |
98 | */ |
99 | template <typename ArrayType, typename ElemType> |
100 | class InplaceArrayBase { |
101 | public: |
102 | /*! |
103 | * \brief Access element at index |
104 | * \param idx The index of the element. |
105 | * \return Const reference to ElemType at the index. |
106 | */ |
107 | const ElemType& operator[](size_t idx) const { |
108 | size_t size = Self()->GetSize(); |
109 | ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n" ; |
110 | return *(reinterpret_cast<ElemType*>(AddressOf(idx))); |
111 | } |
112 | |
113 | /*! |
114 | * \brief Access element at index |
115 | * \param idx The index of the element. |
116 | * \return Reference to ElemType at the index. |
117 | */ |
118 | ElemType& operator[](size_t idx) { |
119 | size_t size = Self()->GetSize(); |
120 | ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n" ; |
121 | return *(reinterpret_cast<ElemType*>(AddressOf(idx))); |
122 | } |
123 | |
124 | /*! |
125 | * \brief Destroy the Inplace Array Base object |
126 | */ |
127 | ~InplaceArrayBase() { |
128 | if (!(std::is_standard_layout<ElemType>::value && std::is_trivial<ElemType>::value)) { |
129 | size_t size = Self()->GetSize(); |
130 | for (size_t i = 0; i < size; ++i) { |
131 | ElemType* fp = reinterpret_cast<ElemType*>(AddressOf(i)); |
132 | fp->ElemType::~ElemType(); |
133 | } |
134 | } |
135 | } |
136 | |
137 | protected: |
138 | /*! |
139 | * \brief Construct a value in place with the arguments. |
140 | * |
141 | * \tparam Args Type parameters of the arguments. |
142 | * \param idx Index of the element. |
143 | * \param args Arguments to construct the new value. |
144 | * |
145 | * \note Please make sure ArrayType::GetSize returns 0 before first call of |
146 | * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. |
147 | */ |
148 | template <typename... Args> |
149 | void EmplaceInit(size_t idx, Args&&... args) { |
150 | void* field_ptr = AddressOf(idx); |
151 | new (field_ptr) ElemType(std::forward<Args>(args)...); |
152 | } |
153 | |
154 | /*! |
155 | * \brief Return the self object for the array. |
156 | * |
157 | * \return Pointer to ArrayType. |
158 | */ |
159 | inline ArrayType* Self() const { |
160 | return static_cast<ArrayType*>(const_cast<InplaceArrayBase*>(this)); |
161 | } |
162 | |
163 | /*! |
164 | * \brief Return the raw pointer to the element at idx. |
165 | * |
166 | * \param idx The index of the element. |
167 | * \return Raw pointer to the element. |
168 | */ |
169 | void* AddressOf(size_t idx) const { |
170 | static_assert( |
171 | alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, |
172 | "The size and alignment of ArrayType should respect " |
173 | "ElemType's alignment." ); |
174 | |
175 | size_t kDataStart = sizeof(ArrayType); |
176 | ArrayType* self = Self(); |
177 | char* data_start = reinterpret_cast<char*>(self) + kDataStart; |
178 | return data_start + idx * sizeof(ElemType); |
179 | } |
180 | }; |
181 | |
/*!
 * \brief Iterator wrapper that exposes the elements of an underlying iterator
 *        after passing them through a converter.
 *
 * Stepping, offsetting and comparison are forwarded to the wrapped TIter.
 * Dereferencing returns Converter::convert(*iter) by value.
 *
 * \tparam Converter A struct that provides a ResultType alias and a convert
 *         function callable as Converter::convert(x).
 * \tparam TIter The content iterator type.
 */
template <typename Converter, typename TIter>
class IterAdapter {
 public:
  using value_type = typename Converter::ResultType;
  using pointer = value_type*;
  using reference = value_type&;
  using difference_type = typename std::iterator_traits<TIter>::difference_type;
  using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

  /*! \brief Wrap the given underlying iterator. */
  explicit IterAdapter(TIter iter) : iter_(iter) {}

  /*! \brief Pre-increment: advance by one and return this adapter. */
  IterAdapter& operator++() {
    ++iter_;
    return *this;
  }

  /*! \brief Pre-decrement: step back by one and return this adapter. */
  IterAdapter& operator--() {
    --iter_;
    return *this;
  }

  /*! \brief Post-increment: advance by one and return the previous position. */
  IterAdapter operator++(int) {
    IterAdapter before(*this);
    ++iter_;
    return before;
  }

  /*! \brief Post-decrement: step back by one and return the previous position. */
  IterAdapter operator--(int) {
    IterAdapter before(*this);
    --iter_;
    return before;
  }

  /*! \brief Adapter positioned \p offset elements after this one. */
  IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); }

  /*! \brief Adapter positioned \p offset elements before this one. */
  IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); }

  /*!
   * \brief Distance between two adapters.
   *
   * Only participates in overload resolution when the wrapped iterator's
   * category is exactly std::random_access_iterator_tag.
   */
  template <typename T = IterAdapter>
  inline typename std::enable_if<
      std::is_same<iterator_category, std::random_access_iterator_tag>::value,
      typename T::difference_type>::type
  operator-(const IterAdapter& rhs) const {
    return iter_ - rhs.iter_;
  }

  /*! \brief Two adapters are equal when their wrapped iterators are equal. */
  bool operator==(IterAdapter other) const { return iter_ == other.iter_; }

  /*! \brief Negation of operator==. */
  bool operator!=(IterAdapter other) const { return !(iter_ == other.iter_); }

  /*! \brief Convert and return the element at the current position. */
  const value_type operator*() const { return Converter::convert(*iter_); }

 private:
  /*! \brief The wrapped iterator. */
  TIter iter_;
};
234 | |
/*!
 * \brief Reverse iterator adapter that adapts TIter to return another type.
 *
 * Moves through the underlying sequence in the opposite direction of TIter:
 * incrementing the adapter decrements the wrapped iterator and vice versa.
 * Dereferencing reads the element the wrapped iterator currently points at
 * (not the one before it, unlike std::reverse_iterator) and returns
 * Converter::convert(*iter) by value.
 *
 * \tparam Converter a struct that contains converting function
 * \tparam TIter the content iterator type.
 */
template <typename Converter, typename TIter>
class ReverseIterAdapter {
 public:
  using difference_type = typename std::iterator_traits<TIter>::difference_type;
  using value_type = typename Converter::ResultType;
  using pointer = typename Converter::ResultType*;
  using reference = typename Converter::ResultType&;  // NOLINT(*)
  using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

  /*! \brief Wrap the given underlying iterator. */
  explicit ReverseIterAdapter(TIter iter) : iter_(iter) {}

  /*! \brief Pre-increment: move one step in reverse direction. */
  ReverseIterAdapter& operator++() {
    --iter_;
    return *this;
  }

  /*! \brief Pre-decrement: move one step in forward direction. */
  ReverseIterAdapter& operator--() {
    ++iter_;
    return *this;
  }

  /*! \brief Post-increment: move in reverse direction, return the previous position. */
  ReverseIterAdapter operator++(int) {
    ReverseIterAdapter copy = *this;
    --iter_;
    return copy;
  }

  /*! \brief Post-decrement: move in forward direction, return the previous position. */
  ReverseIterAdapter operator--(int) {
    ReverseIterAdapter copy = *this;
    ++iter_;
    return copy;
  }

  /*! \brief Adapter \p offset steps further along the reverse traversal. */
  ReverseIterAdapter operator+(difference_type offset) const {
    return ReverseIterAdapter(iter_ - offset);
  }

  /*!
   * \brief Adapter \p offset steps back along the reverse traversal.
   *
   * Counterpart of operator+ and of IterAdapter::operator-(difference_type).
   */
  ReverseIterAdapter operator-(difference_type offset) const {
    return ReverseIterAdapter(iter_ + offset);
  }

  /*!
   * \brief Distance between two adapters, measured in reverse direction.
   *
   * Only participates in overload resolution when the wrapped iterator's
   * category is exactly std::random_access_iterator_tag.
   */
  template <typename T = ReverseIterAdapter>
  typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
                          typename T::difference_type>::type inline
  operator-(const ReverseIterAdapter& rhs) const {
    // Operands are swapped relative to IterAdapter because traversal is reversed.
    return rhs.iter_ - iter_;
  }

  /*! \brief Two adapters are equal when their wrapped iterators are equal. */
  bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; }

  /*! \brief Negation of operator==. */
  bool operator!=(ReverseIterAdapter other) const { return !(*this == other); }

  /*! \brief Convert and return the element at the current position. */
  const value_type operator*() const { return Converter::convert(*iter_); }

 private:
  /*! \brief The wrapped iterator. */
  TIter iter_;
};
286 | |
287 | } // namespace runtime |
288 | |
// Expose selected tvm::runtime names in the root tvm namespace so that they
// can be referred to as tvm::X instead of tvm::runtime::X.
using runtime::Downcast;
using runtime::IterAdapter;
using runtime::make_object;
using runtime::Object;
using runtime::ObjectEqual;
using runtime::ObjectHash;
using runtime::ObjectPtr;
using runtime::ObjectPtrEqual;
using runtime::ObjectPtrHash;
using runtime::ObjectRef;
300 | } // namespace tvm |
301 | |
302 | #endif // TVM_RUNTIME_CONTAINER_BASE_H_ |
303 | |