1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <cstddef> |
17 | #include <cstdlib> |
18 | #include <string> |
19 | |
20 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
21 | #include "tensorflow/core/framework/bounds_check.h" |
22 | #include "tensorflow/core/framework/kernel_def_builder.h" |
23 | #include "tensorflow/core/framework/op.h" |
24 | #include "tensorflow/core/framework/op_kernel.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/tensor_shape.h" |
27 | #include "tensorflow/core/framework/tensor_types.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/kernels/string_util.h" |
30 | #include "tensorflow/core/lib/core/errors.h" |
31 | #include "tensorflow/core/lib/core/stringpiece.h" |
32 | #include "tensorflow/core/platform/types.h" |
33 | #include "tensorflow/core/util/bcast.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | // Position/length can be 32 or 64-bit integers |
38 | template <typename T> |
39 | class SubstrOp : public OpKernel { |
40 | public: |
41 | explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
42 | string unit; |
43 | OP_REQUIRES_OK(ctx, ctx->GetAttr("unit" , &unit)); |
44 | OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_)); |
45 | } |
46 | |
47 | void Compute(OpKernelContext* context) override { |
48 | // Get inputs |
49 | const Tensor& input_tensor = context->input(0); |
50 | const Tensor& pos_tensor = context->input(1); |
51 | const Tensor& len_tensor = context->input(2); |
52 | const TensorShape& input_shape = input_tensor.shape(); |
53 | const TensorShape& pos_shape = pos_tensor.shape(); |
54 | const TensorShape& len_shape = len_tensor.shape(); |
55 | OP_REQUIRES(context, (pos_shape == len_shape), |
56 | errors::InvalidArgument( |
57 | "pos and len should have the same shape, got: " , |
58 | pos_shape.DebugString(), " vs. " , len_shape.DebugString())); |
59 | |
60 | bool is_scalar = TensorShapeUtils::IsScalar(pos_shape); |
61 | |
62 | if (is_scalar || input_shape == pos_shape) { |
63 | // pos/len are either scalar or match the shape of input_tensor |
64 | // Do not need to do broadcasting |
65 | |
66 | // Reshape input |
67 | auto input = input_tensor.flat<tstring>(); |
68 | // Allocate output |
69 | Tensor* output_tensor = nullptr; |
70 | OP_REQUIRES_OK(context, |
71 | context->allocate_output("output" , input_tensor.shape(), |
72 | &output_tensor)); |
73 | auto output = output_tensor->flat<tstring>(); |
74 | if (is_scalar) { |
75 | // Perform Op with scalar pos/len |
76 | const T pos = |
77 | tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()()); |
78 | const T len = |
79 | tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); |
80 | for (size_t i = 0; i < input_tensor.NumElements(); ++i) { |
81 | StringPiece in(input(i)); |
82 | T byte_pos = pos; |
83 | T byte_len = len; |
84 | switch (unit_) { |
85 | case CharUnit::UTF8_CHAR: |
86 | OP_REQUIRES( |
87 | context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), |
88 | errors::InvalidArgument("pos " , pos, " out of range for " , |
89 | "string at index " , i)); |
90 | break; |
91 | case CharUnit::BYTE: |
92 | byte_pos = AdjustedPosIndex(byte_pos, in); |
93 | OP_REQUIRES( |
94 | context, FastBoundsCheck(byte_pos, in.size() + 1), |
95 | errors::InvalidArgument("pos " , pos, " out of range for " , |
96 | "string b'" , in, "' at index " , i)); |
97 | } |
98 | StringPiece sub_in = in.substr(byte_pos, byte_len); |
99 | output(i).assign(sub_in.data(), sub_in.size()); |
100 | } |
101 | } else { |
102 | // Perform Op element-wise with tensor pos/len |
103 | auto pos_flat = pos_tensor.flat<T>(); |
104 | auto len_flat = len_tensor.flat<T>(); |
105 | for (size_t i = 0; i < input_tensor.NumElements(); ++i) { |
106 | StringPiece in(input(i)); |
107 | const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); |
108 | const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); |
109 | T byte_pos = pos; |
110 | T byte_len = len; |
111 | switch (unit_) { |
112 | case CharUnit::UTF8_CHAR: |
113 | OP_REQUIRES( |
114 | context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), |
115 | errors::InvalidArgument("pos " , pos, " out of range for " , |
116 | "string at index " , i)); |
117 | break; |
118 | case CharUnit::BYTE: |
119 | byte_pos = AdjustedPosIndex(byte_pos, in); |
120 | OP_REQUIRES( |
121 | context, FastBoundsCheck(byte_pos, in.size() + 1), |
122 | errors::InvalidArgument("pos " , pos, " out of range for " , |
123 | "string b'" , in, "' at index " , i)); |
124 | } |
125 | StringPiece sub_in = in.substr(byte_pos, byte_len); |
126 | output(i).assign(sub_in.data(), sub_in.size()); |
127 | } |
128 | } |
129 | } else { |
130 | // Perform op with broadcasting |
131 | // TODO: Use ternary broadcasting for once available in Eigen. Current |
132 | // implementation iterates through broadcasted ops element-wise; |
133 | // this should be parallelized. |
134 | |
135 | // Create BCast helper with shape of input and pos/len |
136 | BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape), |
137 | /*fewer_dims_optimization*/ false); |
138 | OP_REQUIRES(context, bcast.IsValid(), |
139 | errors::InvalidArgument( |
140 | "Incompatible shapes: " , input_shape.DebugString(), |
141 | " vs. " , pos_shape.DebugString())); |
142 | TensorShape output_shape = BCast::ToShape(bcast.result_shape()); |
143 | int ndims = output_shape.dims(); |
144 | Tensor* output_tensor = nullptr; |
145 | OP_REQUIRES_OK(context, context->allocate_output("output" , output_shape, |
146 | &output_tensor)); |
147 | switch (ndims) { |
148 | case 1: { |
149 | // Reshape tensors according to BCast results |
150 | auto input = input_tensor.shaped<tstring, 1>(bcast.x_reshape()); |
151 | auto output = output_tensor->shaped<tstring, 1>(bcast.result_shape()); |
152 | auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape()); |
153 | auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape()); |
154 | |
155 | // Allocate temporary buffer for broadcasted position tensor |
156 | Tensor pos_buffer; |
157 | OP_REQUIRES_OK(context, |
158 | context->allocate_temp(DataTypeToEnum<T>::v(), |
159 | output_shape, &pos_buffer)); |
160 | typename TTypes<T, 1>::Tensor pos_bcast( |
161 | pos_buffer.shaped<T, 1>(bcast.result_shape())); |
162 | pos_bcast = |
163 | pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); |
164 | |
165 | // Allocate temporary buffer for broadcasted length tensor |
166 | Tensor len_buffer; |
167 | OP_REQUIRES_OK(context, |
168 | context->allocate_temp(DataTypeToEnum<T>::v(), |
169 | output_shape, &len_buffer)); |
170 | typename TTypes<T, 1>::Tensor len_bcast( |
171 | len_buffer.shaped<T, 1>(bcast.result_shape())); |
172 | len_bcast = |
173 | len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); |
174 | |
175 | // Iterate through broadcasted tensors and perform substr |
176 | for (int i = 0; i < output_shape.dim_size(0); ++i) { |
177 | StringPiece in(input(input.dimension(0) > 1 ? i : 0)); |
178 | const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); |
179 | const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); |
180 | T byte_pos = pos; |
181 | T byte_len = len; |
182 | switch (unit_) { |
183 | case CharUnit::UTF8_CHAR: |
184 | OP_REQUIRES( |
185 | context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), |
186 | errors::InvalidArgument("pos " , pos, " out of range for " , |
187 | "string at index " , i)); |
188 | break; |
189 | case CharUnit::BYTE: |
190 | byte_pos = AdjustedPosIndex(byte_pos, in); |
191 | OP_REQUIRES( |
192 | context, FastBoundsCheck(byte_pos, in.size() + 1), |
193 | errors::InvalidArgument("pos " , pos, " out of range for " , |
194 | "string b'" , in, "' at index " , i)); |
195 | } |
196 | StringPiece sub_in = in.substr(byte_pos, byte_len); |
197 | output(i).assign(sub_in.data(), sub_in.size()); |
198 | } |
199 | break; |
200 | } |
201 | case 2: { |
202 | // Reshape tensors according to BCast results |
203 | auto input = input_tensor.shaped<tstring, 2>(bcast.x_reshape()); |
204 | auto output = output_tensor->shaped<tstring, 2>(bcast.result_shape()); |
205 | auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape()); |
206 | auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape()); |
207 | |
208 | // Allocate temporary buffer for broadcasted position tensor |
209 | Tensor pos_buffer; |
210 | OP_REQUIRES_OK(context, |
211 | context->allocate_temp(DataTypeToEnum<T>::v(), |
212 | output_shape, &pos_buffer)); |
213 | typename TTypes<T, 2>::Tensor pos_bcast( |
214 | pos_buffer.shaped<T, 2>(bcast.result_shape())); |
215 | pos_bcast = |
216 | pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); |
217 | |
218 | // Allocate temporary buffer for broadcasted length tensor |
219 | Tensor len_buffer; |
220 | OP_REQUIRES_OK(context, |
221 | context->allocate_temp(DataTypeToEnum<T>::v(), |
222 | output_shape, &len_buffer)); |
223 | typename TTypes<T, 2>::Tensor len_bcast( |
224 | len_buffer.shaped<T, 2>(bcast.result_shape())); |
225 | len_bcast = |
226 | len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); |
227 | |
228 | // Iterate through broadcasted tensors and perform substr |
229 | for (int i = 0; i < output_shape.dim_size(0); ++i) { |
230 | for (int j = 0; j < output_shape.dim_size(1); ++j) { |
231 | StringPiece in(input(input.dimension(0) > 1 ? i : 0, |
232 | input.dimension(1) > 1 ? j : 0)); |
233 | const T pos = |
234 | tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); |
235 | const T len = |
236 | tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); |
237 | T byte_pos = pos; |
238 | T byte_len = len; |
239 | switch (unit_) { |
240 | case CharUnit::UTF8_CHAR: |
241 | OP_REQUIRES( |
242 | context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), |
243 | errors::InvalidArgument("pos " , pos, " out of range for " , |
244 | "string at index " , i)); |
245 | break; |
246 | case CharUnit::BYTE: |
247 | byte_pos = AdjustedPosIndex(byte_pos, in); |
248 | OP_REQUIRES( |
249 | context, FastBoundsCheck(byte_pos, in.size() + 1), |
250 | errors::InvalidArgument("pos " , pos, " out of range for " , |
251 | "string b'" , in, "' at index (" , |
252 | i, ", " , j, ")" )); |
253 | } |
254 | StringPiece sub_in = in.substr(byte_pos, byte_len); |
255 | output(i, j).assign(sub_in.data(), sub_in.size()); |
256 | } |
257 | } |
258 | break; |
259 | } |
260 | default: { |
261 | context->SetStatus(errors::Unimplemented( |
262 | "Substr broadcast not implemented for " , ndims, " dimensions" )); |
263 | } |
264 | } |
265 | } |
266 | } |
267 | |
268 | private: |
269 | // This adjusts the requested position. Note it does not perform any bound |
270 | // checks. |
271 | static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) { |
272 | if (pos_requested < 0) { |
273 | return s.size() + pos_requested; |
274 | } |
275 | return pos_requested; |
276 | } |
277 | |
278 | // Return true if successful; otherwise, return false if the `pos` argument |
279 | // is out of range in the string. |
280 | static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos, |
281 | T* len) { |
282 | if (*pos >= 0) { |
283 | return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len); |
284 | } else { |
285 | return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len); |
286 | } |
287 | } |
288 | |
289 | static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos, |
290 | const T len, T* char_pos, |
291 | T* char_len) { |
292 | *char_pos = 0; |
293 | // Determine byte position of the substring start. |
294 | if (!ForwardNUTF8CharPositions(in, pos, char_pos)) { |
295 | return false; |
296 | } |
297 | // Determine position of the end of the substring. |
298 | // The length will be capped at the end of the string, and we ignore whether |
299 | // the string had enough characters to handle it or not. |
300 | *char_len = *char_pos; |
301 | ForwardNUTF8CharPositions(in, len, char_len); |
302 | // The length in bytes is the position end of the substring less the start. |
303 | *char_len = *char_len - *char_pos; |
304 | return true; |
305 | } |
306 | |
307 | // This function expects a negative position relative to the end of the |
308 | // string, but will update the character position to a positive number |
309 | // relative to the beginning of the string. |
310 | static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos, |
311 | const T len, T* char_pos, |
312 | T* char_len) { |
313 | // Initially treat the length as position of the end of the substring. |
314 | *char_len = in.size(); |
315 | // This is the number of character to skip from the end of the string to |
316 | // arrive at the position where the substring should end. |
317 | T utf8_chars_to_skip = -pos - len; |
318 | if (utf8_chars_to_skip < 0) { |
319 | utf8_chars_to_skip = 0; |
320 | } |
321 | // Find the byte position where the substring should end using the computed |
322 | // number of characters to skip. |
323 | if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) { |
324 | return false; |
325 | } |
326 | // Next, determine where the substring should begin. The number of chars to |
327 | // skip is the requested position minus the chars we've previously skipped. |
328 | *char_pos = *char_len; |
329 | if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) { |
330 | return false; |
331 | } |
332 | // The length in bytes is the position end of the substring less the start. |
333 | *char_len = *char_len - *char_pos; |
334 | return true; |
335 | } |
336 | |
337 | CharUnit unit_ = CharUnit::BYTE; |
338 | }; |
339 | |
340 | #define REGISTER_SUBSTR(type) \ |
341 | REGISTER_KERNEL_BUILDER( \ |
342 | Name("Substr").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
343 | SubstrOp<type>); |
344 | REGISTER_SUBSTR(int32); |
345 | REGISTER_SUBSTR(int64_t); |
346 | } // namespace tensorflow |
347 | |