1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <cstddef>
17#include <cstdlib>
18#include <string>
19
20#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21#include "tensorflow/core/framework/bounds_check.h"
22#include "tensorflow/core/framework/kernel_def_builder.h"
23#include "tensorflow/core/framework/op.h"
24#include "tensorflow/core/framework/op_kernel.h"
25#include "tensorflow/core/framework/tensor.h"
26#include "tensorflow/core/framework/tensor_shape.h"
27#include "tensorflow/core/framework/tensor_types.h"
28#include "tensorflow/core/framework/types.h"
29#include "tensorflow/core/kernels/string_util.h"
30#include "tensorflow/core/lib/core/errors.h"
31#include "tensorflow/core/lib/core/stringpiece.h"
32#include "tensorflow/core/platform/types.h"
33#include "tensorflow/core/util/bcast.h"
34
35namespace tensorflow {
36
37// Position/length can be 32 or 64-bit integers
38template <typename T>
39class SubstrOp : public OpKernel {
40 public:
41 explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
42 string unit;
43 OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
44 OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
45 }
46
47 void Compute(OpKernelContext* context) override {
48 // Get inputs
49 const Tensor& input_tensor = context->input(0);
50 const Tensor& pos_tensor = context->input(1);
51 const Tensor& len_tensor = context->input(2);
52 const TensorShape& input_shape = input_tensor.shape();
53 const TensorShape& pos_shape = pos_tensor.shape();
54 const TensorShape& len_shape = len_tensor.shape();
55 OP_REQUIRES(context, (pos_shape == len_shape),
56 errors::InvalidArgument(
57 "pos and len should have the same shape, got: ",
58 pos_shape.DebugString(), " vs. ", len_shape.DebugString()));
59
60 bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
61
62 if (is_scalar || input_shape == pos_shape) {
63 // pos/len are either scalar or match the shape of input_tensor
64 // Do not need to do broadcasting
65
66 // Reshape input
67 auto input = input_tensor.flat<tstring>();
68 // Allocate output
69 Tensor* output_tensor = nullptr;
70 OP_REQUIRES_OK(context,
71 context->allocate_output("output", input_tensor.shape(),
72 &output_tensor));
73 auto output = output_tensor->flat<tstring>();
74 if (is_scalar) {
75 // Perform Op with scalar pos/len
76 const T pos =
77 tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()());
78 const T len =
79 tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
80 for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
81 StringPiece in(input(i));
82 T byte_pos = pos;
83 T byte_len = len;
84 switch (unit_) {
85 case CharUnit::UTF8_CHAR:
86 OP_REQUIRES(
87 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
88 errors::InvalidArgument("pos ", pos, " out of range for ",
89 "string at index ", i));
90 break;
91 case CharUnit::BYTE:
92 byte_pos = AdjustedPosIndex(byte_pos, in);
93 OP_REQUIRES(
94 context, FastBoundsCheck(byte_pos, in.size() + 1),
95 errors::InvalidArgument("pos ", pos, " out of range for ",
96 "string b'", in, "' at index ", i));
97 }
98 StringPiece sub_in = in.substr(byte_pos, byte_len);
99 output(i).assign(sub_in.data(), sub_in.size());
100 }
101 } else {
102 // Perform Op element-wise with tensor pos/len
103 auto pos_flat = pos_tensor.flat<T>();
104 auto len_flat = len_tensor.flat<T>();
105 for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
106 StringPiece in(input(i));
107 const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
108 const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
109 T byte_pos = pos;
110 T byte_len = len;
111 switch (unit_) {
112 case CharUnit::UTF8_CHAR:
113 OP_REQUIRES(
114 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
115 errors::InvalidArgument("pos ", pos, " out of range for ",
116 "string at index ", i));
117 break;
118 case CharUnit::BYTE:
119 byte_pos = AdjustedPosIndex(byte_pos, in);
120 OP_REQUIRES(
121 context, FastBoundsCheck(byte_pos, in.size() + 1),
122 errors::InvalidArgument("pos ", pos, " out of range for ",
123 "string b'", in, "' at index ", i));
124 }
125 StringPiece sub_in = in.substr(byte_pos, byte_len);
126 output(i).assign(sub_in.data(), sub_in.size());
127 }
128 }
129 } else {
130 // Perform op with broadcasting
131 // TODO: Use ternary broadcasting for once available in Eigen. Current
132 // implementation iterates through broadcasted ops element-wise;
133 // this should be parallelized.
134
135 // Create BCast helper with shape of input and pos/len
136 BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape),
137 /*fewer_dims_optimization*/ false);
138 OP_REQUIRES(context, bcast.IsValid(),
139 errors::InvalidArgument(
140 "Incompatible shapes: ", input_shape.DebugString(),
141 " vs. ", pos_shape.DebugString()));
142 TensorShape output_shape = BCast::ToShape(bcast.result_shape());
143 int ndims = output_shape.dims();
144 Tensor* output_tensor = nullptr;
145 OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
146 &output_tensor));
147 switch (ndims) {
148 case 1: {
149 // Reshape tensors according to BCast results
150 auto input = input_tensor.shaped<tstring, 1>(bcast.x_reshape());
151 auto output = output_tensor->shaped<tstring, 1>(bcast.result_shape());
152 auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape());
153 auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape());
154
155 // Allocate temporary buffer for broadcasted position tensor
156 Tensor pos_buffer;
157 OP_REQUIRES_OK(context,
158 context->allocate_temp(DataTypeToEnum<T>::v(),
159 output_shape, &pos_buffer));
160 typename TTypes<T, 1>::Tensor pos_bcast(
161 pos_buffer.shaped<T, 1>(bcast.result_shape()));
162 pos_bcast =
163 pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
164
165 // Allocate temporary buffer for broadcasted length tensor
166 Tensor len_buffer;
167 OP_REQUIRES_OK(context,
168 context->allocate_temp(DataTypeToEnum<T>::v(),
169 output_shape, &len_buffer));
170 typename TTypes<T, 1>::Tensor len_bcast(
171 len_buffer.shaped<T, 1>(bcast.result_shape()));
172 len_bcast =
173 len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
174
175 // Iterate through broadcasted tensors and perform substr
176 for (int i = 0; i < output_shape.dim_size(0); ++i) {
177 StringPiece in(input(input.dimension(0) > 1 ? i : 0));
178 const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
179 const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
180 T byte_pos = pos;
181 T byte_len = len;
182 switch (unit_) {
183 case CharUnit::UTF8_CHAR:
184 OP_REQUIRES(
185 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
186 errors::InvalidArgument("pos ", pos, " out of range for ",
187 "string at index ", i));
188 break;
189 case CharUnit::BYTE:
190 byte_pos = AdjustedPosIndex(byte_pos, in);
191 OP_REQUIRES(
192 context, FastBoundsCheck(byte_pos, in.size() + 1),
193 errors::InvalidArgument("pos ", pos, " out of range for ",
194 "string b'", in, "' at index ", i));
195 }
196 StringPiece sub_in = in.substr(byte_pos, byte_len);
197 output(i).assign(sub_in.data(), sub_in.size());
198 }
199 break;
200 }
201 case 2: {
202 // Reshape tensors according to BCast results
203 auto input = input_tensor.shaped<tstring, 2>(bcast.x_reshape());
204 auto output = output_tensor->shaped<tstring, 2>(bcast.result_shape());
205 auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape());
206 auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape());
207
208 // Allocate temporary buffer for broadcasted position tensor
209 Tensor pos_buffer;
210 OP_REQUIRES_OK(context,
211 context->allocate_temp(DataTypeToEnum<T>::v(),
212 output_shape, &pos_buffer));
213 typename TTypes<T, 2>::Tensor pos_bcast(
214 pos_buffer.shaped<T, 2>(bcast.result_shape()));
215 pos_bcast =
216 pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
217
218 // Allocate temporary buffer for broadcasted length tensor
219 Tensor len_buffer;
220 OP_REQUIRES_OK(context,
221 context->allocate_temp(DataTypeToEnum<T>::v(),
222 output_shape, &len_buffer));
223 typename TTypes<T, 2>::Tensor len_bcast(
224 len_buffer.shaped<T, 2>(bcast.result_shape()));
225 len_bcast =
226 len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
227
228 // Iterate through broadcasted tensors and perform substr
229 for (int i = 0; i < output_shape.dim_size(0); ++i) {
230 for (int j = 0; j < output_shape.dim_size(1); ++j) {
231 StringPiece in(input(input.dimension(0) > 1 ? i : 0,
232 input.dimension(1) > 1 ? j : 0));
233 const T pos =
234 tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
235 const T len =
236 tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
237 T byte_pos = pos;
238 T byte_len = len;
239 switch (unit_) {
240 case CharUnit::UTF8_CHAR:
241 OP_REQUIRES(
242 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
243 errors::InvalidArgument("pos ", pos, " out of range for ",
244 "string at index ", i));
245 break;
246 case CharUnit::BYTE:
247 byte_pos = AdjustedPosIndex(byte_pos, in);
248 OP_REQUIRES(
249 context, FastBoundsCheck(byte_pos, in.size() + 1),
250 errors::InvalidArgument("pos ", pos, " out of range for ",
251 "string b'", in, "' at index (",
252 i, ", ", j, ")"));
253 }
254 StringPiece sub_in = in.substr(byte_pos, byte_len);
255 output(i, j).assign(sub_in.data(), sub_in.size());
256 }
257 }
258 break;
259 }
260 default: {
261 context->SetStatus(errors::Unimplemented(
262 "Substr broadcast not implemented for ", ndims, " dimensions"));
263 }
264 }
265 }
266 }
267
268 private:
269 // This adjusts the requested position. Note it does not perform any bound
270 // checks.
271 static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
272 if (pos_requested < 0) {
273 return s.size() + pos_requested;
274 }
275 return pos_requested;
276 }
277
278 // Return true if successful; otherwise, return false if the `pos` argument
279 // is out of range in the string.
280 static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos,
281 T* len) {
282 if (*pos >= 0) {
283 return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len);
284 } else {
285 return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len);
286 }
287 }
288
289 static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos,
290 const T len, T* char_pos,
291 T* char_len) {
292 *char_pos = 0;
293 // Determine byte position of the substring start.
294 if (!ForwardNUTF8CharPositions(in, pos, char_pos)) {
295 return false;
296 }
297 // Determine position of the end of the substring.
298 // The length will be capped at the end of the string, and we ignore whether
299 // the string had enough characters to handle it or not.
300 *char_len = *char_pos;
301 ForwardNUTF8CharPositions(in, len, char_len);
302 // The length in bytes is the position end of the substring less the start.
303 *char_len = *char_len - *char_pos;
304 return true;
305 }
306
307 // This function expects a negative position relative to the end of the
308 // string, but will update the character position to a positive number
309 // relative to the beginning of the string.
310 static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos,
311 const T len, T* char_pos,
312 T* char_len) {
313 // Initially treat the length as position of the end of the substring.
314 *char_len = in.size();
315 // This is the number of character to skip from the end of the string to
316 // arrive at the position where the substring should end.
317 T utf8_chars_to_skip = -pos - len;
318 if (utf8_chars_to_skip < 0) {
319 utf8_chars_to_skip = 0;
320 }
321 // Find the byte position where the substring should end using the computed
322 // number of characters to skip.
323 if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) {
324 return false;
325 }
326 // Next, determine where the substring should begin. The number of chars to
327 // skip is the requested position minus the chars we've previously skipped.
328 *char_pos = *char_len;
329 if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) {
330 return false;
331 }
332 // The length in bytes is the position end of the substring less the start.
333 *char_len = *char_len - *char_pos;
334 return true;
335 }
336
337 CharUnit unit_ = CharUnit::BYTE;
338};
339
340#define REGISTER_SUBSTR(type) \
341 REGISTER_KERNEL_BUILDER( \
342 Name("Substr").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
343 SubstrOp<type>);
344REGISTER_SUBSTR(int32);
345REGISTER_SUBSTR(int64_t);
346} // namespace tensorflow
347