1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
// Util methods to read and write string tensors.
// A string tensor is stored as a flat char buffer with the following layout:
// [0, 3] 4 bytes: N, the number of strings in the tensor, in little endian.
// [(i+1)*4, (i+1)*4+3] 4 bytes: offset of the i-th string, in little endian,
//   for i from 0 to N-1.
// [(N+1)*4, (N+1)*4+3] 4 bytes: length of the whole char buffer.
// [offset(i), offset(i+1) - 1]: content of the i-th string.
// The header therefore occupies 4 * (N + 2) bytes. The total-length slot sits
// where offset(N) would be, so the formula above also covers the last string.
// Example of a string tensor holding the two strings "AB" and "":
// [
//   2, 0, 0, 0,    # 2 strings.
//   16, 0, 0, 0,   # 0-th string starts at index 16 (header is 4 * (2 + 2)).
//   18, 0, 0, 0,   # 1-st string starts at index 18.
//   18, 0, 0, 0,   # total length of the array.
//   'A', 'B',      # 0-th string [16..17]: "AB"
// ]                # 1-st string: empty (starts and ends at index 18).
31 | // |
// A typical usage:
// In op.Eval(context, node):
//   DynamicBuffer buf;
//   # Append the string "AB" to the dynamic buffer (nothing is written to
//   # the tensor yet).
//   buf.AddString("AB", 2);
//   # Write the content of the DynamicBuffer to the tensor, in the string
//   # tensor format described above.
//   buf.WriteToTensor(tensor, nullptr);
40 | |
41 | #ifndef TENSORFLOW_LITE_STRING_UTIL_H_ |
42 | #define TENSORFLOW_LITE_STRING_UTIL_H_ |
43 | |
44 | #include <stddef.h> |
45 | #include <stdint.h> |
46 | |
47 | #include <vector> |
48 | |
49 | #include "tensorflow/lite/c/common.h" |
50 | #include "tensorflow/lite/string_type.h" |
51 | |
52 | namespace tflite { |
53 | |
// Non-owning view of a string: a pointer to its first character plus its
// length. `str` is not guaranteed to be NUL-terminated (strings inside a string
// tensor are stored back to back with no terminator), so always bound reads by
// `len`. The referenced characters must outlive the StringRef.
//
// Declared as a named struct rather than `typedef struct { ... }` so that the
// type can be forward-declared; its linkage name is unchanged.
struct StringRef {
  const char* str;
  int len;
};
59 | |
60 | // DynamicBuffer holds temporary buffer that will be used to create a dynamic |
61 | // tensor. A typical usage is to initialize a DynamicBuffer object, fill in |
62 | // content and call CreateStringTensor in op.Eval(). |
63 | class DynamicBuffer { |
64 | public: |
65 | DynamicBuffer() : offset_({0}) {} |
66 | |
67 | // Add string to dynamic buffer by resizing the buffer and copying the data. |
68 | void AddString(const StringRef& string); |
69 | |
70 | // Add string to dynamic buffer by resizing the buffer and copying the data. |
71 | void AddString(const char* str, size_t len); |
72 | |
73 | // Join a list of string with separator, and add as a single string to the |
74 | // buffer. |
75 | void AddJoinedString(const std::vector<StringRef>& strings, char separator); |
76 | void AddJoinedString(const std::vector<StringRef>& strings, |
77 | StringRef separator); |
78 | |
79 | // Fill content into a buffer and returns the number of bytes stored. |
80 | // The function allocates space for the buffer but does NOT take ownership. |
81 | int WriteToBuffer(char** buffer); |
82 | |
83 | // Fill content into a string tensor, with the given new_shape. The new shape |
84 | // must match the number of strings in this object. Caller relinquishes |
85 | // ownership of new_shape. If 'new_shape' is nullptr, keep the tensor's |
86 | // existing shape. |
87 | void WriteToTensor(TfLiteTensor* tensor, TfLiteIntArray* new_shape); |
88 | |
89 | // Fill content into a string tensor. Set shape to {num_strings}. |
90 | void WriteToTensorAsVector(TfLiteTensor* tensor); |
91 | |
92 | private: |
93 | // Data buffer to store contents of strings, not including headers. |
94 | std::vector<char> data_; |
95 | // Offset of the starting index of each string in data buffer. |
96 | std::vector<int32_t> offset_; |
97 | }; |
98 | |
99 | // Return num of strings in a String tensor. |
100 | int GetStringCount(const void* raw_buffer); |
101 | int GetStringCount(const TfLiteTensor* tensor); |
102 | |
103 | // Get String pointer and length of index-th string in tensor. |
104 | // NOTE: This will not create a copy of string data. |
105 | StringRef GetString(const void* raw_buffer, int string_index); |
106 | StringRef GetString(const TfLiteTensor* tensor, int string_index); |
107 | } // namespace tflite |
108 | |
109 | #endif // TENSORFLOW_LITE_STRING_UTIL_H_ |
110 | |