1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/container/string.h
22 * \brief Runtime String container types.
23 */
24#ifndef TVM_RUNTIME_CONTAINER_STRING_H_
25#define TVM_RUNTIME_CONTAINER_STRING_H_
26
27#include <dmlc/logging.h>
28#include <tvm/runtime/container/base.h>
29#include <tvm/runtime/logging.h>
30#include <tvm/runtime/memory.h>
31#include <tvm/runtime/object.h>
32
33#include <algorithm>
34#include <cstddef>
35#include <cstring>
36#include <initializer_list>
37#include <memory>
38#include <string>
39#include <string_view>
40#include <type_traits>
41#include <unordered_map>
42#include <utility>
43#include <vector>
44
45namespace tvm {
46namespace runtime {
47
48// Forward declare TVMArgValue
49class TVMArgValue;
50
51/*! \brief An object representing string. It's POD type. */
52class StringObj : public Object {
53 public:
54 /*! \brief The pointer to string data. */
55 const char* data;
56
57 /*! \brief The length of the string object. */
58 uint64_t size;
59
60 static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
61 static constexpr const char* _type_key = "runtime.String";
62 TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);
63
64 private:
65 /*! \brief String object which is moved from std::string container. */
66 class FromStd;
67
68 friend class String;
69};
70
71/*!
72 * \brief Reference to string objects.
73 *
74 * \code
75 *
76 * // Example to create runtime String reference object from std::string
77 * std::string s = "hello world";
78 *
79 * // You can create the reference from existing std::string
80 * String ref{std::move(s)};
81 *
82 * // You can rebind the reference to another string.
83 * ref = std::string{"hello world2"};
84 *
85 * // You can use the reference as hash map key
86 * std::unordered_map<String, int32_t> m;
87 * m[ref] = 1;
88 *
89 * // You can compare the reference object with other string objects
90 * assert(ref == "hello world", true);
91 *
92 * // You can convert the reference to std::string again
93 * string s2 = (string)ref;
94 *
95 * \endcode
96 */
97class String : public ObjectRef {
98 public:
99 /*!
100 * \brief Construct an empty string.
101 */
102 String() : String(std::string()) {}
103 /*!
104 * \brief Construct a new String object
105 *
106 * \param other The moved/copied std::string object
107 *
108 * \note If user passes const reference, it will trigger copy. If it's rvalue,
109 * it will be moved into other.
110 */
111 String(std::string other); // NOLINT(*)
112
113 /*!
114 * \brief Construct a new String object
115 *
116 * \param other a char array.
117 */
118 String(const char* other) // NOLINT(*)
119 : String(std::string(other)) {}
120
121 /*!
122 * \brief Construct a new null object
123 */
124 String(std::nullptr_t) // NOLINT(*)
125 : ObjectRef(nullptr) {}
126
127 /*!
128 * \brief Change the value the reference object points to.
129 *
130 * \param other The value for the new String
131 *
132 */
133 inline String& operator=(std::string other);
134
135 /*!
136 * \brief Change the value the reference object points to.
137 *
138 * \param other The value for the new String
139 */
140 inline String& operator=(const char* other);
141
142 /*!
143 * \brief Compares this String object to other
144 *
145 * \param other The String to compare with.
146 *
147 * \return zero if both char sequences compare equal. negative if this appear
148 * before other, positive otherwise.
149 */
150 int compare(const String& other) const {
151 return memncmp(data(), other.data(), size(), other.size());
152 }
153
154 /*!
155 * \brief Compares this String object to other
156 *
157 * \param other The string to compare with.
158 *
159 * \return zero if both char sequences compare equal. negative if this appear
160 * before other, positive otherwise.
161 */
162 int compare(const std::string& other) const {
163 return memncmp(data(), other.data(), size(), other.size());
164 }
165
166 /*!
167 * \brief Compares this to other
168 *
169 * \param other The character array to compare with.
170 *
171 * \return zero if both char sequences compare equal. negative if this appear
172 * before other, positive otherwise.
173 */
174 int compare(const char* other) const {
175 return memncmp(data(), other, size(), std::strlen(other));
176 }
177
178 /*!
179 * \brief Returns a pointer to the char array in the string.
180 *
181 * \return const char*
182 */
183 const char* c_str() const { return get()->data; }
184
185 /*!
186 * \brief Return the length of the string
187 *
188 * \return size_t string length
189 */
190 size_t size() const {
191 const auto* ptr = get();
192 return ptr->size;
193 }
194
195 /*!
196 * \brief Return the length of the string
197 *
198 * \return size_t string length
199 */
200 size_t length() const { return size(); }
201
202 /*!
203 * \brief Retun if the string is empty
204 *
205 * \return true if empty, false otherwise.
206 */
207 bool empty() const { return size() == 0; }
208
209 /*!
210 * \brief Read an element.
211 * \param pos The position at which to read the character.
212 *
213 * \return The char at position
214 */
215 char at(size_t pos) const {
216 if (pos < size()) {
217 return data()[pos];
218 } else {
219 throw std::out_of_range("tvm::String index out of bounds");
220 }
221 }
222
223 /*!
224 * \brief Return the data pointer
225 *
226 * \return const char* data pointer
227 */
228 const char* data() const { return get()->data; }
229
230 /*!
231 * \brief Convert String to an std::string object
232 *
233 * \return std::string
234 */
235 operator std::string() const { return std::string{get()->data, size()}; }
236
237 /*!
238 * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String
239 * \param val The value to be checked
240 * \return A boolean indicating if val can be converted to String
241 */
242 inline static bool CanConvertFrom(const TVMArgValue& val);
243
244 /*!
245 * \brief Hash the binary bytes
246 * \param data The data pointer
247 * \param size The size of the bytes.
248 * \return the hash value.
249 */
250 static size_t HashBytes(const char* data, size_t size) {
251 // This function falls back to string copy with c++11 compiler and is
252 // recommended to be compiled with c++14
253 return std::hash<std::string_view>()(std::string_view(data, size));
254 }
255
256 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
257
258 private:
259 /*!
260 * \brief Compare two char sequence
261 *
262 * \param lhs Pointers to the char array to compare
263 * \param rhs Pointers to the char array to compare
264 * \param lhs_count Length of the char array to compare
265 * \param rhs_count Length of the char array to compare
266 * \return int zero if both char sequences compare equal. negative if this
267 * appear before other, positive otherwise.
268 */
269 static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count);
270
271 /*!
272 * \brief Concatenate two char sequences
273 *
274 * \param lhs Pointers to the lhs char array
275 * \param lhs_size The size of the lhs char array
276 * \param rhs Pointers to the rhs char array
277 * \param rhs_size The size of the rhs char array
278 *
279 * \return The concatenated char sequence
280 */
281 static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) {
282 std::string ret(lhs, lhs_size);
283 ret.append(rhs, rhs_size);
284 return String(ret);
285 }
286
287 // Overload + operator
288 friend String operator+(const String& lhs, const String& rhs);
289 friend String operator+(const String& lhs, const std::string& rhs);
290 friend String operator+(const std::string& lhs, const String& rhs);
291 friend String operator+(const String& lhs, const char* rhs);
292 friend String operator+(const char* lhs, const String& rhs);
293
294 friend struct tvm::runtime::ObjectEqual;
295};
296
297/*! \brief An object representing string moved from std::string. */
298class StringObj::FromStd : public StringObj {
299 public:
300 /*!
301 * \brief Construct a new FromStd object
302 *
303 * \param other The moved/copied std::string object
304 *
305 * \note If user passes const reference, it will trigger copy. If it's rvalue,
306 * it will be moved into other.
307 */
308 explicit FromStd(std::string other) : data_container{other} {}
309
310 private:
311 /*! \brief Container that holds the memory. */
312 std::string data_container;
313
314 friend class String;
315};
316
317inline String::String(std::string other) {
318 auto ptr = make_object<StringObj::FromStd>(std::move(other));
319 ptr->size = ptr->data_container.size();
320 ptr->data = ptr->data_container.data();
321 data_ = std::move(ptr);
322}
323
324inline String& String::operator=(std::string other) {
325 String replace{std::move(other)};
326 data_.swap(replace.data_);
327 return *this;
328}
329
330inline String& String::operator=(const char* other) { return operator=(std::string(other)); }
331
332inline String operator+(const String& lhs, const String& rhs) {
333 size_t lhs_size = lhs.size();
334 size_t rhs_size = rhs.size();
335 return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
336}
337
338inline String operator+(const String& lhs, const std::string& rhs) {
339 size_t lhs_size = lhs.size();
340 size_t rhs_size = rhs.size();
341 return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
342}
343
344inline String operator+(const std::string& lhs, const String& rhs) {
345 size_t lhs_size = lhs.size();
346 size_t rhs_size = rhs.size();
347 return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
348}
349
350inline String operator+(const char* lhs, const String& rhs) {
351 size_t lhs_size = std::strlen(lhs);
352 size_t rhs_size = rhs.size();
353 return String::Concat(lhs, lhs_size, rhs.data(), rhs_size);
354}
355
356inline String operator+(const String& lhs, const char* rhs) {
357 size_t lhs_size = lhs.size();
358 size_t rhs_size = std::strlen(rhs);
359 return String::Concat(lhs.data(), lhs_size, rhs, rhs_size);
360}
361
362// Overload < operator
363inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; }
364
365inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; }
366
367inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; }
368
369inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; }
370
371inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; }
372
373// Overload > operator
374inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; }
375
376inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; }
377
378inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; }
379
380inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; }
381
382inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; }
383
384// Overload <= operator
385inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; }
386
387inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }
388
389inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; }
390
391inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; }
392
393inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }
394
395// Overload >= operator
396inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; }
397
398inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; }
399
400inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; }
401
402inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; }
403
404inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; }
405
406// Overload == operator
407inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; }
408
409inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; }
410
411inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; }
412
413inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; }
414
415inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; }
416
417// Overload != operator
418inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; }
419
420inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; }
421
422inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; }
423
424inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; }
425
426inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; }
427
428inline std::ostream& operator<<(std::ostream& out, const String& input) {
429 out.write(input.data(), input.size());
430 return out;
431}
432
433inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
434 if (lhs == rhs && lhs_count == rhs_count) return 0;
435
436 for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) {
437 if (lhs[i] < rhs[i]) return -1;
438 if (lhs[i] > rhs[i]) return 1;
439 }
440 if (lhs_count < rhs_count) {
441 return -1;
442 } else if (lhs_count > rhs_count) {
443 return 1;
444 } else {
445 return 0;
446 }
447}
448
449inline size_t ObjectHash::operator()(const ObjectRef& a) const {
450 if (const auto* str = a.as<StringObj>()) {
451 return String::HashBytes(str->data, str->size);
452 }
453 return ObjectPtrHash()(a);
454}
455
456inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) const {
457 if (a.same_as(b)) {
458 return true;
459 }
460 if (const auto* str_a = a.as<StringObj>()) {
461 if (const auto* str_b = b.as<StringObj>()) {
462 return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0;
463 }
464 }
465 return false;
466}
467} // namespace runtime
468
469// expose the functions to the root namespace.
470using runtime::String;
471using runtime::StringObj;
472} // namespace tvm
473
474namespace std {
475
476template <>
477struct hash<::tvm::runtime::String> {
478 std::size_t operator()(const ::tvm::runtime::String& str) const {
479 return ::tvm::runtime::String::HashBytes(str.data(), str.size());
480 }
481};
482} // namespace std
483
484#endif // TVM_RUNTIME_CONTAINER_STRING_H_
485