1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/string_ops.cc.
17
18#include <string>
19
20#include "tensorflow/core/framework/kernel_def_builder.h"
21#include "tensorflow/core/framework/op_kernel.h"
22#include "tensorflow/core/framework/tensor.h"
23#include "tensorflow/core/lib/core/errors.h"
24#include "tensorflow/core/lib/core/status.h"
25#include "tensorflow/core/lib/core/stringpiece.h"
26#include "tensorflow/core/lib/strings/str_util.h"
27
28namespace tensorflow {
29namespace {
30// Split input string `str` based on a character delimiter.
31// Returns a vector of StringPieces which are valid as long as input `str`
32// is valid.
33// Note: The single character delimiter is a common case and is implemented as
34// a series of finds in the input string, making it much more efficient than
35// SplitOnCharSet.
36template <typename Predicate>
37std::vector<StringPiece> SplitOnChar(const tstring& str, const char delim,
38 Predicate p) {
39 std::vector<StringPiece> result;
40 StringPiece text(str);
41 auto f = text.find(delim);
42 while (f != StringPiece::npos) {
43 StringPiece token = text.substr(0, f);
44 if (p(token)) {
45 result.emplace_back(token);
46 }
47 text.remove_prefix(f + 1);
48 f = text.find(delim);
49 }
50 if (p(text)) {
51 result.push_back(text);
52 }
53 return result;
54}
55
56// Split input string `str` based on a set of character delimiters.
57// Returns a vector of StringPieces which are valid as long as input `str`
58// is valid.
59// Based on str_util::Split.
60template <typename Predicate>
61std::vector<StringPiece> SplitOnCharSet(const tstring& str,
62 const tstring& delim_set, Predicate p) {
63 std::vector<StringPiece> result;
64 StringPiece text(str);
65 StringPiece delims(delim_set);
66 size_t token_start = 0;
67 for (size_t i = 0; i < text.size() + 1; i++) {
68 if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) {
69 StringPiece token(text.data() + token_start, i - token_start);
70 if (p(token)) {
71 result.emplace_back(token);
72 }
73 token_start = i + 1;
74 }
75 }
76 return result;
77}
78
79// Split input string `str` based on given delimiter.
80// Returns a vector of StringPieces which are valid as long as input `str`
81// is valid.
82template <typename Predicate>
83std::vector<StringPiece> Split(const tstring& str, const tstring& delimiter,
84 Predicate predicate) {
85 if (str.empty()) {
86 return std::vector<StringPiece>();
87 }
88 if (delimiter.empty()) {
89 std::vector<StringPiece> result;
90 result.resize(str.size());
91 for (size_t i = 0; i < str.size(); ++i) {
92 result[i] = StringPiece(str.data() + i, 1);
93 }
94 return result;
95 }
96 if (delimiter.size() == 1) {
97 return SplitOnChar(str, delimiter[0], predicate);
98 }
99 return SplitOnCharSet(str, delimiter, predicate);
100}
101
102std::vector<StringPiece> SplitV2(const tstring& str, StringPiece sep,
103 int maxsplit) {
104 // This SplitV2 method matches the behavior of python's str.split:
105 // If sep is given, consecutive delimiters are not grouped together
106 // and are deemed to delimit empty strings (for example, '1,,2'.split(',')
107 // returns ['1', '', '2']). The sep argument may consist of multiple
108 // characters (for example, '1<>2<>3'.split('<>') returns ['1', '2', '3']).
109 // Splitting an empty string with a specified separator returns [''].
110 //
111 // If sep is not specified or is None, a different splitting algorithm is
112 // applied: runs of consecutive whitespace are regarded as a single
113 // separator, and the result will contain no empty strings at the start or
114 // end if the string has leading or trailing whitespace. Consequently,
115 // splitting an empty string or a string consisting of just whitespace
116 // with a None separator returns [].
117
118 std::vector<StringPiece> result;
119
120 StringPiece text(str);
121 if (maxsplit == 0) {
122 result.emplace_back(text);
123 return result;
124 }
125
126 if (sep.empty()) {
127 StringPiece token;
128 // Remove leading whitespaces.
129 str_util::RemoveLeadingWhitespace(&text);
130 int split = 0;
131 while (str_util::ConsumeNonWhitespace(&text, &token)) {
132 result.push_back(token);
133 str_util::RemoveLeadingWhitespace(&text);
134 ++split;
135 if (maxsplit > 0 && split == maxsplit) {
136 result.push_back(text);
137 return result;
138 }
139 }
140 return result;
141 }
142 auto p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
143 int split = 0;
144 while (p != text.end()) {
145 StringPiece token = text.substr(0, p - text.begin());
146 result.push_back(token);
147 text.remove_prefix(token.size());
148 text.remove_prefix(sep.size());
149 ++split;
150 if (maxsplit > 0 && split == maxsplit) {
151 result.push_back(StringPiece(text));
152 return result;
153 }
154 p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
155 }
156 result.push_back(text);
157 return result;
158}
159
160} // namespace
161
162class StringSplitOp : public OpKernel {
163 public:
164 explicit StringSplitOp(OpKernelConstruction* context)
165 : OpKernel(context), skip_empty_(true) {
166 bool skip_empty;
167 // By default skip_empty_ is true. We only get the value from attr if it is
168 // available, so that it is backward compatible.
169 if (context->GetAttr("skip_empty", &skip_empty).ok()) {
170 skip_empty_ = skip_empty;
171 }
172 }
173
174 void Compute(OpKernelContext* ctx) override {
175 const Tensor* input_tensor;
176 OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
177 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()),
178 errors::InvalidArgument("input must be a vector, got shape: ",
179 input_tensor->shape().DebugString()));
180
181 const auto input_vec = input_tensor->vec<tstring>();
182 const int64_t batch_size = input_vec.dimension(0);
183
184 const Tensor* delimiter_tensor;
185 OP_REQUIRES_OK(ctx, ctx->input("delimiter", &delimiter_tensor));
186 OP_REQUIRES(
187 ctx, TensorShapeUtils::IsScalar(delimiter_tensor->shape()),
188 errors::InvalidArgument("delimiter must be a scalar, got shape: ",
189 delimiter_tensor->shape().DebugString()));
190 const auto delimiter_vec = delimiter_tensor->flat<tstring>();
191 const tstring& delimiter = delimiter_vec(0);
192 // Empty delimiter means split the input character by character.
193 std::vector<StringPiece> tokens;
194 // Guess that we'll be unpacking a handful of tokens per example.
195 static constexpr int kReserveSize = 4;
196 tokens.reserve(batch_size * kReserveSize);
197
198 int64_t output_size = 0;
199 int64_t max_num_entries = 0;
200 std::vector<int64_t> num_indices(batch_size);
201 for (int64_t i = 0; i < batch_size; ++i) {
202 std::vector<StringPiece> parts =
203 skip_empty_ ? Split(input_vec(i), delimiter, str_util::SkipEmpty())
204 : Split(input_vec(i), delimiter, str_util::AllowEmpty());
205 int64_t n_entries = parts.size();
206 num_indices[i] = n_entries;
207 output_size += n_entries;
208 max_num_entries = std::max(max_num_entries, n_entries);
209 tokens.insert(tokens.end(), std::make_move_iterator(parts.begin()),
210 std::make_move_iterator(parts.end()));
211 }
212
213 Tensor* sp_indices_t;
214 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}),
215 &sp_indices_t));
216 Tensor* sp_tokens_t;
217 OP_REQUIRES_OK(
218 ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t));
219 Tensor* sp_shape_t;
220 OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t));
221
222 auto sp_indices = sp_indices_t->matrix<int64_t>();
223 auto sp_tokens = sp_tokens_t->vec<tstring>();
224 auto sp_shape = sp_shape_t->vec<int64_t>();
225 sp_shape(0) = batch_size;
226 sp_shape(1) = max_num_entries;
227 size_t c = 0;
228 for (size_t i = 0; i < batch_size; ++i) {
229 for (size_t j = 0; j < num_indices[i]; ++j) {
230 sp_indices(c, 0) = i;
231 sp_indices(c, 1) = j;
232 sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
233 ++c;
234 }
235 }
236 }
237
238 private:
239 bool skip_empty_;
240};
241
242class StringSplitV2Op : public OpKernel {
243 public:
244 explicit StringSplitV2Op(OpKernelConstruction* context)
245 : OpKernel(context), maxsplit_(-1) {
246 OP_REQUIRES_OK(context, context->GetAttr("maxsplit", &maxsplit_));
247 }
248
249 void Compute(OpKernelContext* ctx) override {
250 const Tensor* input_tensor;
251 OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
252 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()),
253 errors::InvalidArgument("input must be a vector, got shape: ",
254 input_tensor->shape().DebugString()));
255
256 const auto input_vec = input_tensor->vec<tstring>();
257 const int64_t batch_size = input_vec.dimension(0);
258
259 const Tensor* sep_tensor;
260 OP_REQUIRES_OK(ctx, ctx->input("sep", &sep_tensor));
261 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sep_tensor->shape()),
262 errors::InvalidArgument("sep must be a scalar, got shape: ",
263 sep_tensor->shape().DebugString()));
264 const auto sep_vec = sep_tensor->flat<tstring>();
265 StringPiece sep(sep_vec(0));
266 std::vector<StringPiece> tokens;
267 // Guess that we'll be unpacking a handful of tokens per example.
268 static constexpr int kReserveSize = 4;
269 tokens.reserve(batch_size * kReserveSize);
270
271 int64_t output_size = 0;
272 int64_t max_num_entries = 0;
273 std::vector<int64_t> num_indices(batch_size);
274 for (int64_t i = 0; i < batch_size; ++i) {
275 std::vector<StringPiece> parts = SplitV2(input_vec(i), sep, maxsplit_);
276 int64_t n_entries = parts.size();
277 num_indices[i] = n_entries;
278 output_size += n_entries;
279 max_num_entries = std::max(max_num_entries, n_entries);
280 tokens.insert(tokens.end(), parts.begin(), parts.end());
281 }
282
283 Tensor* sp_indices_t;
284 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}),
285 &sp_indices_t));
286 Tensor* sp_tokens_t;
287 OP_REQUIRES_OK(
288 ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t));
289 Tensor* sp_shape_t;
290 OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t));
291
292 auto sp_indices = sp_indices_t->matrix<int64_t>();
293 auto sp_tokens = sp_tokens_t->vec<tstring>();
294 auto sp_shape = sp_shape_t->vec<int64_t>();
295 sp_shape(0) = batch_size;
296 sp_shape(1) = max_num_entries;
297 size_t c = 0;
298 for (size_t i = 0; i < batch_size; ++i) {
299 for (size_t j = 0; j < num_indices[i]; ++j) {
300 sp_indices(c, 0) = i;
301 sp_indices(c, 1) = j;
302 sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
303 ++c;
304 }
305 }
306 }
307
308 private:
309 int maxsplit_;
310};
311
// Register the two kernels defined above under the op names "StringSplit"
// and "StringSplitV2" for the CPU device.
REGISTER_KERNEL_BUILDER(Name("StringSplit").Device(DEVICE_CPU), StringSplitOp);
REGISTER_KERNEL_BUILDER(Name("StringSplitV2").Device(DEVICE_CPU),
 StringSplitV2Op);
315
316} // namespace tensorflow
317