1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/string_ops.cc. |
17 | |
18 | #include <string> |
19 | |
20 | #include "tensorflow/core/framework/kernel_def_builder.h" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | #include "tensorflow/core/lib/core/status.h" |
25 | #include "tensorflow/core/lib/core/stringpiece.h" |
26 | #include "tensorflow/core/lib/strings/str_util.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace { |
30 | // Split input string `str` based on a character delimiter. |
31 | // Returns a vector of StringPieces which are valid as long as input `str` |
32 | // is valid. |
33 | // Note: The single character delimiter is a common case and is implemented as |
34 | // a series of finds in the input string, making it much more efficient than |
35 | // SplitOnCharSet. |
36 | template <typename Predicate> |
37 | std::vector<StringPiece> SplitOnChar(const tstring& str, const char delim, |
38 | Predicate p) { |
39 | std::vector<StringPiece> result; |
40 | StringPiece text(str); |
41 | auto f = text.find(delim); |
42 | while (f != StringPiece::npos) { |
43 | StringPiece token = text.substr(0, f); |
44 | if (p(token)) { |
45 | result.emplace_back(token); |
46 | } |
47 | text.remove_prefix(f + 1); |
48 | f = text.find(delim); |
49 | } |
50 | if (p(text)) { |
51 | result.push_back(text); |
52 | } |
53 | return result; |
54 | } |
55 | |
56 | // Split input string `str` based on a set of character delimiters. |
57 | // Returns a vector of StringPieces which are valid as long as input `str` |
58 | // is valid. |
59 | // Based on str_util::Split. |
60 | template <typename Predicate> |
61 | std::vector<StringPiece> SplitOnCharSet(const tstring& str, |
62 | const tstring& delim_set, Predicate p) { |
63 | std::vector<StringPiece> result; |
64 | StringPiece text(str); |
65 | StringPiece delims(delim_set); |
66 | size_t token_start = 0; |
67 | for (size_t i = 0; i < text.size() + 1; i++) { |
68 | if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) { |
69 | StringPiece token(text.data() + token_start, i - token_start); |
70 | if (p(token)) { |
71 | result.emplace_back(token); |
72 | } |
73 | token_start = i + 1; |
74 | } |
75 | } |
76 | return result; |
77 | } |
78 | |
79 | // Split input string `str` based on given delimiter. |
80 | // Returns a vector of StringPieces which are valid as long as input `str` |
81 | // is valid. |
82 | template <typename Predicate> |
83 | std::vector<StringPiece> Split(const tstring& str, const tstring& delimiter, |
84 | Predicate predicate) { |
85 | if (str.empty()) { |
86 | return std::vector<StringPiece>(); |
87 | } |
88 | if (delimiter.empty()) { |
89 | std::vector<StringPiece> result; |
90 | result.resize(str.size()); |
91 | for (size_t i = 0; i < str.size(); ++i) { |
92 | result[i] = StringPiece(str.data() + i, 1); |
93 | } |
94 | return result; |
95 | } |
96 | if (delimiter.size() == 1) { |
97 | return SplitOnChar(str, delimiter[0], predicate); |
98 | } |
99 | return SplitOnCharSet(str, delimiter, predicate); |
100 | } |
101 | |
102 | std::vector<StringPiece> SplitV2(const tstring& str, StringPiece sep, |
103 | int maxsplit) { |
104 | // This SplitV2 method matches the behavior of python's str.split: |
105 | // If sep is given, consecutive delimiters are not grouped together |
106 | // and are deemed to delimit empty strings (for example, '1,,2'.split(',') |
107 | // returns ['1', '', '2']). The sep argument may consist of multiple |
108 | // characters (for example, '1<>2<>3'.split('<>') returns ['1', '2', '3']). |
109 | // Splitting an empty string with a specified separator returns ['']. |
110 | // |
111 | // If sep is not specified or is None, a different splitting algorithm is |
112 | // applied: runs of consecutive whitespace are regarded as a single |
113 | // separator, and the result will contain no empty strings at the start or |
114 | // end if the string has leading or trailing whitespace. Consequently, |
115 | // splitting an empty string or a string consisting of just whitespace |
116 | // with a None separator returns []. |
117 | |
118 | std::vector<StringPiece> result; |
119 | |
120 | StringPiece text(str); |
121 | if (maxsplit == 0) { |
122 | result.emplace_back(text); |
123 | return result; |
124 | } |
125 | |
126 | if (sep.empty()) { |
127 | StringPiece token; |
128 | // Remove leading whitespaces. |
129 | str_util::RemoveLeadingWhitespace(&text); |
130 | int split = 0; |
131 | while (str_util::ConsumeNonWhitespace(&text, &token)) { |
132 | result.push_back(token); |
133 | str_util::RemoveLeadingWhitespace(&text); |
134 | ++split; |
135 | if (maxsplit > 0 && split == maxsplit) { |
136 | result.push_back(text); |
137 | return result; |
138 | } |
139 | } |
140 | return result; |
141 | } |
142 | auto p = std::search(text.begin(), text.end(), sep.begin(), sep.end()); |
143 | int split = 0; |
144 | while (p != text.end()) { |
145 | StringPiece token = text.substr(0, p - text.begin()); |
146 | result.push_back(token); |
147 | text.remove_prefix(token.size()); |
148 | text.remove_prefix(sep.size()); |
149 | ++split; |
150 | if (maxsplit > 0 && split == maxsplit) { |
151 | result.push_back(StringPiece(text)); |
152 | return result; |
153 | } |
154 | p = std::search(text.begin(), text.end(), sep.begin(), sep.end()); |
155 | } |
156 | result.push_back(text); |
157 | return result; |
158 | } |
159 | |
160 | } // namespace |
161 | |
162 | class StringSplitOp : public OpKernel { |
163 | public: |
164 | explicit StringSplitOp(OpKernelConstruction* context) |
165 | : OpKernel(context), skip_empty_(true) { |
166 | bool skip_empty; |
167 | // By default skip_empty_ is true. We only get the value from attr if it is |
168 | // available, so that it is backward compatible. |
169 | if (context->GetAttr("skip_empty" , &skip_empty).ok()) { |
170 | skip_empty_ = skip_empty; |
171 | } |
172 | } |
173 | |
174 | void Compute(OpKernelContext* ctx) override { |
175 | const Tensor* input_tensor; |
176 | OP_REQUIRES_OK(ctx, ctx->input("input" , &input_tensor)); |
177 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()), |
178 | errors::InvalidArgument("input must be a vector, got shape: " , |
179 | input_tensor->shape().DebugString())); |
180 | |
181 | const auto input_vec = input_tensor->vec<tstring>(); |
182 | const int64_t batch_size = input_vec.dimension(0); |
183 | |
184 | const Tensor* delimiter_tensor; |
185 | OP_REQUIRES_OK(ctx, ctx->input("delimiter" , &delimiter_tensor)); |
186 | OP_REQUIRES( |
187 | ctx, TensorShapeUtils::IsScalar(delimiter_tensor->shape()), |
188 | errors::InvalidArgument("delimiter must be a scalar, got shape: " , |
189 | delimiter_tensor->shape().DebugString())); |
190 | const auto delimiter_vec = delimiter_tensor->flat<tstring>(); |
191 | const tstring& delimiter = delimiter_vec(0); |
192 | // Empty delimiter means split the input character by character. |
193 | std::vector<StringPiece> tokens; |
194 | // Guess that we'll be unpacking a handful of tokens per example. |
195 | static constexpr int kReserveSize = 4; |
196 | tokens.reserve(batch_size * kReserveSize); |
197 | |
198 | int64_t output_size = 0; |
199 | int64_t max_num_entries = 0; |
200 | std::vector<int64_t> num_indices(batch_size); |
201 | for (int64_t i = 0; i < batch_size; ++i) { |
202 | std::vector<StringPiece> parts = |
203 | skip_empty_ ? Split(input_vec(i), delimiter, str_util::SkipEmpty()) |
204 | : Split(input_vec(i), delimiter, str_util::AllowEmpty()); |
205 | int64_t n_entries = parts.size(); |
206 | num_indices[i] = n_entries; |
207 | output_size += n_entries; |
208 | max_num_entries = std::max(max_num_entries, n_entries); |
209 | tokens.insert(tokens.end(), std::make_move_iterator(parts.begin()), |
210 | std::make_move_iterator(parts.end())); |
211 | } |
212 | |
213 | Tensor* sp_indices_t; |
214 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}), |
215 | &sp_indices_t)); |
216 | Tensor* sp_tokens_t; |
217 | OP_REQUIRES_OK( |
218 | ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t)); |
219 | Tensor* sp_shape_t; |
220 | OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t)); |
221 | |
222 | auto sp_indices = sp_indices_t->matrix<int64_t>(); |
223 | auto sp_tokens = sp_tokens_t->vec<tstring>(); |
224 | auto sp_shape = sp_shape_t->vec<int64_t>(); |
225 | sp_shape(0) = batch_size; |
226 | sp_shape(1) = max_num_entries; |
227 | size_t c = 0; |
228 | for (size_t i = 0; i < batch_size; ++i) { |
229 | for (size_t j = 0; j < num_indices[i]; ++j) { |
230 | sp_indices(c, 0) = i; |
231 | sp_indices(c, 1) = j; |
232 | sp_tokens(c).assign(tokens[c].data(), tokens[c].size()); |
233 | ++c; |
234 | } |
235 | } |
236 | } |
237 | |
238 | private: |
239 | bool skip_empty_; |
240 | }; |
241 | |
242 | class StringSplitV2Op : public OpKernel { |
243 | public: |
244 | explicit StringSplitV2Op(OpKernelConstruction* context) |
245 | : OpKernel(context), maxsplit_(-1) { |
246 | OP_REQUIRES_OK(context, context->GetAttr("maxsplit" , &maxsplit_)); |
247 | } |
248 | |
249 | void Compute(OpKernelContext* ctx) override { |
250 | const Tensor* input_tensor; |
251 | OP_REQUIRES_OK(ctx, ctx->input("input" , &input_tensor)); |
252 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()), |
253 | errors::InvalidArgument("input must be a vector, got shape: " , |
254 | input_tensor->shape().DebugString())); |
255 | |
256 | const auto input_vec = input_tensor->vec<tstring>(); |
257 | const int64_t batch_size = input_vec.dimension(0); |
258 | |
259 | const Tensor* sep_tensor; |
260 | OP_REQUIRES_OK(ctx, ctx->input("sep" , &sep_tensor)); |
261 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sep_tensor->shape()), |
262 | errors::InvalidArgument("sep must be a scalar, got shape: " , |
263 | sep_tensor->shape().DebugString())); |
264 | const auto sep_vec = sep_tensor->flat<tstring>(); |
265 | StringPiece sep(sep_vec(0)); |
266 | std::vector<StringPiece> tokens; |
267 | // Guess that we'll be unpacking a handful of tokens per example. |
268 | static constexpr int kReserveSize = 4; |
269 | tokens.reserve(batch_size * kReserveSize); |
270 | |
271 | int64_t output_size = 0; |
272 | int64_t max_num_entries = 0; |
273 | std::vector<int64_t> num_indices(batch_size); |
274 | for (int64_t i = 0; i < batch_size; ++i) { |
275 | std::vector<StringPiece> parts = SplitV2(input_vec(i), sep, maxsplit_); |
276 | int64_t n_entries = parts.size(); |
277 | num_indices[i] = n_entries; |
278 | output_size += n_entries; |
279 | max_num_entries = std::max(max_num_entries, n_entries); |
280 | tokens.insert(tokens.end(), parts.begin(), parts.end()); |
281 | } |
282 | |
283 | Tensor* sp_indices_t; |
284 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}), |
285 | &sp_indices_t)); |
286 | Tensor* sp_tokens_t; |
287 | OP_REQUIRES_OK( |
288 | ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t)); |
289 | Tensor* sp_shape_t; |
290 | OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t)); |
291 | |
292 | auto sp_indices = sp_indices_t->matrix<int64_t>(); |
293 | auto sp_tokens = sp_tokens_t->vec<tstring>(); |
294 | auto sp_shape = sp_shape_t->vec<int64_t>(); |
295 | sp_shape(0) = batch_size; |
296 | sp_shape(1) = max_num_entries; |
297 | size_t c = 0; |
298 | for (size_t i = 0; i < batch_size; ++i) { |
299 | for (size_t j = 0; j < num_indices[i]; ++j) { |
300 | sp_indices(c, 0) = i; |
301 | sp_indices(c, 1) = j; |
302 | sp_tokens(c).assign(tokens[c].data(), tokens[c].size()); |
303 | ++c; |
304 | } |
305 | } |
306 | } |
307 | |
308 | private: |
309 | int maxsplit_; |
310 | }; |
311 | |
312 | REGISTER_KERNEL_BUILDER(Name("StringSplit" ).Device(DEVICE_CPU), StringSplitOp); |
313 | REGISTER_KERNEL_BUILDER(Name("StringSplitV2" ).Device(DEVICE_CPU), |
314 | StringSplitV2Op); |
315 | |
316 | } // namespace tensorflow |
317 | |