1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/runtime/container/string.h |
22 | * \brief Runtime String container types. |
23 | */ |
24 | #ifndef TVM_RUNTIME_CONTAINER_STRING_H_ |
25 | #define TVM_RUNTIME_CONTAINER_STRING_H_ |
26 | |
27 | #include <dmlc/logging.h> |
28 | #include <tvm/runtime/container/base.h> |
29 | #include <tvm/runtime/logging.h> |
30 | #include <tvm/runtime/memory.h> |
31 | #include <tvm/runtime/object.h> |
32 | |
33 | #include <algorithm> |
34 | #include <cstddef> |
35 | #include <cstring> |
36 | #include <initializer_list> |
37 | #include <memory> |
38 | #include <string> |
39 | #include <string_view> |
40 | #include <type_traits> |
41 | #include <unordered_map> |
42 | #include <utility> |
43 | #include <vector> |
44 | |
45 | namespace tvm { |
46 | namespace runtime { |
47 | |
48 | // Forward declare TVMArgValue |
49 | class TVMArgValue; |
50 | |
51 | /*! \brief An object representing string. It's POD type. */ |
52 | class StringObj : public Object { |
53 | public: |
54 | /*! \brief The pointer to string data. */ |
55 | const char* data; |
56 | |
57 | /*! \brief The length of the string object. */ |
58 | uint64_t size; |
59 | |
60 | static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString; |
61 | static constexpr const char* _type_key = "runtime.String" ; |
62 | TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object); |
63 | |
64 | private: |
65 | /*! \brief String object which is moved from std::string container. */ |
66 | class FromStd; |
67 | |
68 | friend class String; |
69 | }; |
70 | |
71 | /*! |
72 | * \brief Reference to string objects. |
73 | * |
74 | * \code |
75 | * |
76 | * // Example to create runtime String reference object from std::string |
77 | * std::string s = "hello world"; |
78 | * |
79 | * // You can create the reference from existing std::string |
80 | * String ref{std::move(s)}; |
81 | * |
82 | * // You can rebind the reference to another string. |
83 | * ref = std::string{"hello world2"}; |
84 | * |
85 | * // You can use the reference as hash map key |
86 | * std::unordered_map<String, int32_t> m; |
87 | * m[ref] = 1; |
88 | * |
89 | * // You can compare the reference object with other string objects |
90 | * assert(ref == "hello world", true); |
91 | * |
92 | * // You can convert the reference to std::string again |
93 | * string s2 = (string)ref; |
94 | * |
95 | * \endcode |
96 | */ |
97 | class String : public ObjectRef { |
98 | public: |
99 | /*! |
100 | * \brief Construct an empty string. |
101 | */ |
102 | String() : String(std::string()) {} |
103 | /*! |
104 | * \brief Construct a new String object |
105 | * |
106 | * \param other The moved/copied std::string object |
107 | * |
108 | * \note If user passes const reference, it will trigger copy. If it's rvalue, |
109 | * it will be moved into other. |
110 | */ |
111 | String(std::string other); // NOLINT(*) |
112 | |
113 | /*! |
114 | * \brief Construct a new String object |
115 | * |
116 | * \param other a char array. |
117 | */ |
118 | String(const char* other) // NOLINT(*) |
119 | : String(std::string(other)) {} |
120 | |
121 | /*! |
122 | * \brief Construct a new null object |
123 | */ |
124 | String(std::nullptr_t) // NOLINT(*) |
125 | : ObjectRef(nullptr) {} |
126 | |
127 | /*! |
128 | * \brief Change the value the reference object points to. |
129 | * |
130 | * \param other The value for the new String |
131 | * |
132 | */ |
133 | inline String& operator=(std::string other); |
134 | |
135 | /*! |
136 | * \brief Change the value the reference object points to. |
137 | * |
138 | * \param other The value for the new String |
139 | */ |
140 | inline String& operator=(const char* other); |
141 | |
142 | /*! |
143 | * \brief Compares this String object to other |
144 | * |
145 | * \param other The String to compare with. |
146 | * |
147 | * \return zero if both char sequences compare equal. negative if this appear |
148 | * before other, positive otherwise. |
149 | */ |
150 | int compare(const String& other) const { |
151 | return memncmp(data(), other.data(), size(), other.size()); |
152 | } |
153 | |
154 | /*! |
155 | * \brief Compares this String object to other |
156 | * |
157 | * \param other The string to compare with. |
158 | * |
159 | * \return zero if both char sequences compare equal. negative if this appear |
160 | * before other, positive otherwise. |
161 | */ |
162 | int compare(const std::string& other) const { |
163 | return memncmp(data(), other.data(), size(), other.size()); |
164 | } |
165 | |
166 | /*! |
167 | * \brief Compares this to other |
168 | * |
169 | * \param other The character array to compare with. |
170 | * |
171 | * \return zero if both char sequences compare equal. negative if this appear |
172 | * before other, positive otherwise. |
173 | */ |
174 | int compare(const char* other) const { |
175 | return memncmp(data(), other, size(), std::strlen(other)); |
176 | } |
177 | |
178 | /*! |
179 | * \brief Returns a pointer to the char array in the string. |
180 | * |
181 | * \return const char* |
182 | */ |
183 | const char* c_str() const { return get()->data; } |
184 | |
185 | /*! |
186 | * \brief Return the length of the string |
187 | * |
188 | * \return size_t string length |
189 | */ |
190 | size_t size() const { |
191 | const auto* ptr = get(); |
192 | return ptr->size; |
193 | } |
194 | |
195 | /*! |
196 | * \brief Return the length of the string |
197 | * |
198 | * \return size_t string length |
199 | */ |
200 | size_t length() const { return size(); } |
201 | |
202 | /*! |
203 | * \brief Retun if the string is empty |
204 | * |
205 | * \return true if empty, false otherwise. |
206 | */ |
207 | bool empty() const { return size() == 0; } |
208 | |
209 | /*! |
210 | * \brief Read an element. |
211 | * \param pos The position at which to read the character. |
212 | * |
213 | * \return The char at position |
214 | */ |
215 | char at(size_t pos) const { |
216 | if (pos < size()) { |
217 | return data()[pos]; |
218 | } else { |
219 | throw std::out_of_range("tvm::String index out of bounds" ); |
220 | } |
221 | } |
222 | |
223 | /*! |
224 | * \brief Return the data pointer |
225 | * |
226 | * \return const char* data pointer |
227 | */ |
228 | const char* data() const { return get()->data; } |
229 | |
230 | /*! |
231 | * \brief Convert String to an std::string object |
232 | * |
233 | * \return std::string |
234 | */ |
235 | operator std::string() const { return std::string{get()->data, size()}; } |
236 | |
237 | /*! |
238 | * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String |
239 | * \param val The value to be checked |
240 | * \return A boolean indicating if val can be converted to String |
241 | */ |
242 | inline static bool CanConvertFrom(const TVMArgValue& val); |
243 | |
244 | /*! |
245 | * \brief Hash the binary bytes |
246 | * \param data The data pointer |
247 | * \param size The size of the bytes. |
248 | * \return the hash value. |
249 | */ |
250 | static size_t HashBytes(const char* data, size_t size) { |
251 | // This function falls back to string copy with c++11 compiler and is |
252 | // recommended to be compiled with c++14 |
253 | return std::hash<std::string_view>()(std::string_view(data, size)); |
254 | } |
255 | |
256 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); |
257 | |
258 | private: |
259 | /*! |
260 | * \brief Compare two char sequence |
261 | * |
262 | * \param lhs Pointers to the char array to compare |
263 | * \param rhs Pointers to the char array to compare |
264 | * \param lhs_count Length of the char array to compare |
265 | * \param rhs_count Length of the char array to compare |
266 | * \return int zero if both char sequences compare equal. negative if this |
267 | * appear before other, positive otherwise. |
268 | */ |
269 | static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); |
270 | |
271 | /*! |
272 | * \brief Concatenate two char sequences |
273 | * |
274 | * \param lhs Pointers to the lhs char array |
275 | * \param lhs_size The size of the lhs char array |
276 | * \param rhs Pointers to the rhs char array |
277 | * \param rhs_size The size of the rhs char array |
278 | * |
279 | * \return The concatenated char sequence |
280 | */ |
281 | static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { |
282 | std::string ret(lhs, lhs_size); |
283 | ret.append(rhs, rhs_size); |
284 | return String(ret); |
285 | } |
286 | |
287 | // Overload + operator |
288 | friend String operator+(const String& lhs, const String& rhs); |
289 | friend String operator+(const String& lhs, const std::string& rhs); |
290 | friend String operator+(const std::string& lhs, const String& rhs); |
291 | friend String operator+(const String& lhs, const char* rhs); |
292 | friend String operator+(const char* lhs, const String& rhs); |
293 | |
294 | friend struct tvm::runtime::ObjectEqual; |
295 | }; |
296 | |
297 | /*! \brief An object representing string moved from std::string. */ |
298 | class StringObj::FromStd : public StringObj { |
299 | public: |
300 | /*! |
301 | * \brief Construct a new FromStd object |
302 | * |
303 | * \param other The moved/copied std::string object |
304 | * |
305 | * \note If user passes const reference, it will trigger copy. If it's rvalue, |
306 | * it will be moved into other. |
307 | */ |
308 | explicit FromStd(std::string other) : data_container{other} {} |
309 | |
310 | private: |
311 | /*! \brief Container that holds the memory. */ |
312 | std::string data_container; |
313 | |
314 | friend class String; |
315 | }; |
316 | |
317 | inline String::String(std::string other) { |
318 | auto ptr = make_object<StringObj::FromStd>(std::move(other)); |
319 | ptr->size = ptr->data_container.size(); |
320 | ptr->data = ptr->data_container.data(); |
321 | data_ = std::move(ptr); |
322 | } |
323 | |
324 | inline String& String::operator=(std::string other) { |
325 | String replace{std::move(other)}; |
326 | data_.swap(replace.data_); |
327 | return *this; |
328 | } |
329 | |
330 | inline String& String::operator=(const char* other) { return operator=(std::string(other)); } |
331 | |
332 | inline String operator+(const String& lhs, const String& rhs) { |
333 | size_t lhs_size = lhs.size(); |
334 | size_t rhs_size = rhs.size(); |
335 | return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); |
336 | } |
337 | |
338 | inline String operator+(const String& lhs, const std::string& rhs) { |
339 | size_t lhs_size = lhs.size(); |
340 | size_t rhs_size = rhs.size(); |
341 | return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); |
342 | } |
343 | |
344 | inline String operator+(const std::string& lhs, const String& rhs) { |
345 | size_t lhs_size = lhs.size(); |
346 | size_t rhs_size = rhs.size(); |
347 | return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); |
348 | } |
349 | |
350 | inline String operator+(const char* lhs, const String& rhs) { |
351 | size_t lhs_size = std::strlen(lhs); |
352 | size_t rhs_size = rhs.size(); |
353 | return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); |
354 | } |
355 | |
356 | inline String operator+(const String& lhs, const char* rhs) { |
357 | size_t lhs_size = lhs.size(); |
358 | size_t rhs_size = std::strlen(rhs); |
359 | return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); |
360 | } |
361 | |
362 | // Overload < operator |
363 | inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; } |
364 | |
365 | inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } |
366 | |
367 | inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } |
368 | |
369 | inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } |
370 | |
371 | inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } |
372 | |
373 | // Overload > operator |
374 | inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } |
375 | |
376 | inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } |
377 | |
378 | inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } |
379 | |
380 | inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } |
381 | |
382 | inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } |
383 | |
384 | // Overload <= operator |
385 | inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } |
386 | |
387 | inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } |
388 | |
389 | inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } |
390 | |
391 | inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } |
392 | |
393 | inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } |
394 | |
395 | // Overload >= operator |
396 | inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; } |
397 | |
398 | inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } |
399 | |
400 | inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } |
401 | |
402 | inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } |
403 | |
404 | inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; } |
405 | |
406 | // Overload == operator |
407 | inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; } |
408 | |
409 | inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; } |
410 | |
411 | inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } |
412 | |
413 | inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } |
414 | |
415 | inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } |
416 | |
417 | // Overload != operator |
418 | inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } |
419 | |
420 | inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } |
421 | |
422 | inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } |
423 | |
424 | inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } |
425 | |
426 | inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } |
427 | |
428 | inline std::ostream& operator<<(std::ostream& out, const String& input) { |
429 | out.write(input.data(), input.size()); |
430 | return out; |
431 | } |
432 | |
433 | inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { |
434 | if (lhs == rhs && lhs_count == rhs_count) return 0; |
435 | |
436 | for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { |
437 | if (lhs[i] < rhs[i]) return -1; |
438 | if (lhs[i] > rhs[i]) return 1; |
439 | } |
440 | if (lhs_count < rhs_count) { |
441 | return -1; |
442 | } else if (lhs_count > rhs_count) { |
443 | return 1; |
444 | } else { |
445 | return 0; |
446 | } |
447 | } |
448 | |
449 | inline size_t ObjectHash::operator()(const ObjectRef& a) const { |
450 | if (const auto* str = a.as<StringObj>()) { |
451 | return String::HashBytes(str->data, str->size); |
452 | } |
453 | return ObjectPtrHash()(a); |
454 | } |
455 | |
456 | inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) const { |
457 | if (a.same_as(b)) { |
458 | return true; |
459 | } |
460 | if (const auto* str_a = a.as<StringObj>()) { |
461 | if (const auto* str_b = b.as<StringObj>()) { |
462 | return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0; |
463 | } |
464 | } |
465 | return false; |
466 | } |
467 | } // namespace runtime |
468 | |
469 | // expose the functions to the root namespace. |
470 | using runtime::String; |
471 | using runtime::StringObj; |
472 | } // namespace tvm |
473 | |
474 | namespace std { |
475 | |
476 | template <> |
477 | struct hash<::tvm::runtime::String> { |
478 | std::size_t operator()(const ::tvm::runtime::String& str) const { |
479 | return ::tvm::runtime::String::HashBytes(str.data(), str.size()); |
480 | } |
481 | }; |
482 | } // namespace std |
483 | |
484 | #endif // TVM_RUNTIME_CONTAINER_STRING_H_ |
485 | |