1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <stdint.h> |
17 | |
18 | #include <cstddef> |
19 | #include <functional> |
20 | #include <memory> |
21 | #include <string> |
22 | #include <vector> |
23 | |
24 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
25 | #include "unicode/appendable.h" // from @icu |
26 | #include "unicode/schriter.h" // from @icu |
27 | #include "unicode/uchar.h" // from @icu |
28 | #include "unicode/ucnv.h" // from @icu |
29 | #include "unicode/ucnv_err.h" // from @icu |
30 | #include "unicode/umachine.h" // from @icu |
31 | #include "unicode/uniset.h" // from @icu |
32 | #include "unicode/unistr.h" // from @icu |
33 | #include "unicode/uset.h" // from @icu |
34 | #include "unicode/utf.h" // from @icu |
35 | #include "unicode/utypes.h" // from @icu |
36 | #include "tensorflow/core/framework/bounds_check.h" |
37 | #include "tensorflow/core/framework/kernel_def_builder.h" |
38 | #include "tensorflow/core/framework/op.h" |
39 | #include "tensorflow/core/framework/op_kernel.h" |
40 | #include "tensorflow/core/framework/register_types.h" |
41 | #include "tensorflow/core/framework/tensor.h" |
42 | #include "tensorflow/core/framework/tensor_shape.h" |
43 | #include "tensorflow/core/framework/tensor_types.h" |
44 | #include "tensorflow/core/framework/types.h" |
45 | #include "tensorflow/core/kernels/string_util.h" |
46 | #include "tensorflow/core/lib/core/errors.h" |
47 | #include "tensorflow/core/lib/core/status.h" |
48 | #include "tensorflow/core/lib/core/stringpiece.h" |
49 | #include "tensorflow/core/platform/types.h" |
50 | #include "tensorflow/core/util/bcast.h" |
51 | #include "tensorflow/core/util/ptr_util.h" |
52 | |
53 | namespace tensorflow { |
54 | namespace { |
55 | |
56 | void Encode(const UnicodeEncoding encoding, const icu::UnicodeString& in, |
57 | tstring* out) { |
58 | if (encoding == UnicodeEncoding::UTF8) { |
59 | out->clear(); |
60 | in.toUTF8String(*out); |
61 | } else if (encoding == UnicodeEncoding::UTF16BE) { |
62 | // TODO(gbillock): consider using the |
63 | // extract(char *dest, int32_t destCapacity, UConverter *cnv) |
64 | // for UTF16/32 |
65 | out->clear(); // subtle: must come before reserve() |
66 | out->reserve(2 * in.length() + 1); |
67 | const char16_t* buf = in.getBuffer(); |
68 | for (int i = 0; i < in.length(); ++i) { |
69 | // Emit big-endian encoding for UTF-16 always. |
70 | out->push_back((buf[i] & 0xFF00) >> 8); |
71 | out->push_back(buf[i] & 0x00FF); |
72 | } |
73 | } else if (encoding == UnicodeEncoding::UTF32BE) { |
74 | out->clear(); // subtle: must come before reserve() |
75 | out->reserve(4 * in.countChar32() + 1); |
76 | icu::StringCharacterIterator it(in); |
77 | UChar32 ch; |
78 | while (it.hasNext()) { |
79 | ch = it.next32PostInc(); |
80 | out->push_back((ch & 0xFF000000) >> 24); |
81 | out->push_back((ch & 0x00FF0000) >> 16); |
82 | out->push_back((ch & 0x0000FF00) >> 8); |
83 | out->push_back((ch & 0x000000FF)); |
84 | } |
85 | } |
86 | } |
87 | |
88 | // This error callback is only useful for finding illegal encoding errors when |
89 | // we want to be strict -- otherwise illegal encodings are replaced on read |
90 | // with 0xFFFD and signaled to the callback. |
91 | void unicode_error_callback(const void* context, UConverterToUnicodeArgs* args, |
92 | const char* codeUnits, int32_t length, |
93 | UConverterCallbackReason reason, |
94 | UErrorCode* pErrorCode) { |
95 | // Careful: this depends on setting up the context settings when the |
96 | // callback is registered. |
97 | bool* format_error = const_cast<bool*>(static_cast<const bool*>(context)); |
98 | |
99 | if (reason == UCNV_UNASSIGNED || reason == UCNV_ILLEGAL || |
100 | reason == UCNV_IRREGULAR) { |
101 | *format_error = true; |
102 | } |
103 | |
104 | // Side note: the default behavior in this case is that without a substitution |
105 | // made by the callback, the UConverter will signal an error to the iterator |
106 | // making the string iteration bail out. Instead, forward to the built-in |
107 | // substitution handler. |
108 | UCNV_TO_U_CALLBACK_SUBSTITUTE(nullptr, args, codeUnits, length, reason, |
109 | pErrorCode); |
110 | } |
111 | |
112 | // Iterates through a source string given the provided input UConverter specific |
113 | // to the encoding for that string. Calls a provided callback for each codepoint |
114 | // consumed. Provides the callback with the codepoint and the number of bytes |
115 | // consumed from the input string to produce it. If there are invalid encoding |
116 | // loci in the source string, they will be provided as a 0xFFFD codepoint to |
117 | // the callback, unless the "fail_on_formatting_error" arg is set, in which |
118 | // case the callback will be passed the signal that there is such an invalid |
119 | // encoding position. |
120 | // callback: function(UChar32 codepoint, int num_bytes_consumed_from_source_str, |
121 | // bool fatal_format_error) |
122 | void IterateUnicodeString(const string& str, UConverter* converter, |
123 | std::function<void(UChar32, int, bool)> callback) { |
124 | const char* source = str.data(); |
125 | const char* limit = str.data() + str.length(); |
126 | UErrorCode status = U_ZERO_ERROR; |
127 | |
128 | UConverterToUCallback oldAction = nullptr; |
129 | const void* oldContext = nullptr; |
130 | bool format_error = false; |
131 | |
132 | // Subtle. You can't make a function pointer from a std::function. :-( |
133 | // Instead, we pass the boolean pointer as the "context" object. |
134 | ucnv_setToUCallBack(converter, unicode_error_callback, &format_error, |
135 | &oldAction, &oldContext, &status); |
136 | if (U_FAILURE(status)) { |
137 | LOG(ERROR) << "Could not set unicode error callback on converter" ; |
138 | return; |
139 | } |
140 | |
141 | while (source < limit) { |
142 | const char* source_pre_fetch = source; |
143 | // Note: ucnv_getNextUChar returns 0xFFFD on an encoding error. |
144 | UChar32 next_char = ucnv_getNextUChar(converter, &source, limit, &status); |
145 | if (U_FAILURE(status)) { |
146 | source = limit; |
147 | } |
148 | int bytes_consumed = source - source_pre_fetch; |
149 | callback(next_char, bytes_consumed, format_error); |
150 | format_error = false; |
151 | } |
152 | |
153 | ucnv_setToUCallBack(converter, oldAction, oldContext, nullptr, nullptr, |
154 | &status); |
155 | } |
156 | |
157 | // Lifecycle wrapper for UConverter making it easier to use with thread_local. |
158 | // TODO(gbillock): Consider whether to use the higher-level convert API and |
159 | // create a specialized fast code path for UTF8. |
160 | class WrappedConverter { |
161 | public: |
162 | WrappedConverter() {} |
163 | |
164 | ~WrappedConverter() { |
165 | if (converter_) { |
166 | ucnv_close(converter_); |
167 | } |
168 | } |
169 | |
170 | void init(const string& name) { |
171 | if (converter_ && name == name_) { |
172 | // Note: this reset is not typically needed, but if not done, then in some |
173 | // cases the cached converter will maintain state of input endianness |
174 | // which isn't valid from input to input in every batched case. |
175 | ucnv_reset(converter_); |
176 | return; |
177 | } |
178 | |
179 | if (converter_) { |
180 | ucnv_close(converter_); |
181 | converter_ = nullptr; |
182 | name_ = "" ; |
183 | } |
184 | |
185 | UErrorCode status = U_ZERO_ERROR; |
186 | converter_ = ucnv_open(name.c_str(), &status); |
187 | if (U_FAILURE(status)) { |
188 | if (converter_) { |
189 | ucnv_close(converter_); |
190 | converter_ = nullptr; |
191 | } |
192 | } else { |
193 | name_ = name; |
194 | } |
195 | } |
196 | |
197 | UConverter* converter_ = nullptr; |
198 | string name_; |
199 | }; |
200 | |
201 | struct ErrorOptions { |
202 | UChar32 subst = 0xFFFD; |
203 | bool elide_replacement = false; |
204 | bool replace_control_chars = false; |
205 | bool error_on_malformatting = false; |
206 | }; |
207 | |
208 | Status GetErrorOptions(OpKernelConstruction* ctx, ErrorOptions* out) { |
209 | *out = ErrorOptions(); |
210 | |
211 | string error_policy; |
212 | TF_RETURN_IF_ERROR(ctx->GetAttr("errors" , &error_policy)); |
213 | |
214 | if (error_policy == "replace" ) { |
215 | out->elide_replacement = false; |
216 | } else if (error_policy == "ignore" ) { |
217 | out->elide_replacement = true; |
218 | } else if (error_policy == "strict" ) { |
219 | out->error_on_malformatting = true; |
220 | } else { |
221 | return errors::InvalidArgument( |
222 | "errors policy must be one of 'strict', 'replace', or 'ignore'" ); |
223 | } |
224 | |
225 | int32_t replacement_char; |
226 | TF_RETURN_IF_ERROR(ctx->GetAttr("replacement_char" , &replacement_char)); |
227 | |
228 | if (replacement_char >= UCHAR_MIN_VALUE && |
229 | replacement_char <= UCHAR_MAX_VALUE) { |
230 | out->subst = replacement_char; |
231 | } else { |
232 | return errors::InvalidArgument( |
233 | "replacement_char out of unicode codepoint range" ); |
234 | } |
235 | |
236 | if (ctx->HasAttr("replace_control_characters" )) { |
237 | TF_RETURN_IF_ERROR(ctx->GetAttr("replace_control_characters" , |
238 | &(out->replace_control_chars))); |
239 | } |
240 | |
241 | return OkStatus(); |
242 | } |
243 | |
244 | inline bool ShouldHandleFormatError(const ErrorOptions& error_options, |
245 | UChar32 ch, bool format_error) { |
246 | return ((error_options.replace_control_chars && ch <= 0x1F) || format_error); |
247 | } |
248 | |
249 | } // namespace |
250 | |
251 | class UnicodeTranscodeOp : public OpKernel { |
252 | public: |
253 | explicit UnicodeTranscodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
254 | OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_)); |
255 | |
256 | string output_encoding; |
257 | OP_REQUIRES_OK(ctx, ctx->GetAttr("output_encoding" , &output_encoding)); |
258 | OP_REQUIRES_OK(ctx, |
259 | ParseUnicodeEncoding(output_encoding, &output_encoding_)); |
260 | |
261 | OP_REQUIRES_OK(ctx, ctx->GetAttr("input_encoding" , &input_encoding_)); |
262 | // Make a temporary UConverter to ensure it will create without error |
263 | // at execution time (and to warm any data caches the converter needs). |
264 | // This instance is not used. |
265 | std::unique_ptr<WrappedConverter> input_encoder = |
266 | std::make_unique<WrappedConverter>(); |
267 | input_encoder->init(input_encoding_); |
268 | OP_REQUIRES(ctx, input_encoder->converter_, |
269 | errors::InvalidArgument( |
270 | "Could not create converter for input encoding: " + |
271 | input_encoding_)); |
272 | } |
273 | |
274 | void Compute(OpKernelContext* ctx) override { |
275 | const Tensor* input_tensor; |
276 | OP_REQUIRES_OK(ctx, ctx->input("input" , &input_tensor)); |
277 | |
278 | static thread_local std::unique_ptr<WrappedConverter> input_encoder; |
279 | if (!input_encoder) { |
280 | input_encoder.reset(new WrappedConverter()); |
281 | } |
282 | input_encoder->init(input_encoding_); |
283 | OP_REQUIRES(ctx, input_encoder->converter_, |
284 | errors::InvalidArgument( |
285 | "Could not create converter for input encoding: " + |
286 | input_encoding_)); |
287 | |
288 | // Output may be forwardable from input, in which case work in-place. |
289 | Tensor* output_tensor; |
290 | std::unique_ptr<Tensor> maybe_forwarded = |
291 | ctx->forward_input(0 /*input_index*/, 0 /*output_index*/, |
292 | tensorflow::DT_STRING, input_tensor->shape(), |
293 | ctx->input_memory_type(0), ctx->input_alloc_attr(0)); |
294 | if (maybe_forwarded) { |
295 | output_tensor = maybe_forwarded.get(); |
296 | OP_REQUIRES_OK(ctx, ctx->set_output("output" , *output_tensor)); |
297 | } else { |
298 | OP_REQUIRES_OK(ctx, ctx->allocate_output("output" , input_tensor->shape(), |
299 | &output_tensor)); |
300 | output_tensor->flat<tstring>() = input_tensor->flat<tstring>(); |
301 | } |
302 | |
303 | auto output_flat = output_tensor->flat<tstring>(); |
304 | bool found_any_format_error = false; |
305 | for (size_t i = 0; i < output_flat.size(); ++i) { |
306 | Transcode(&(output_flat(i)), input_encoder->converter_, |
307 | &found_any_format_error); |
308 | } |
309 | if (error_options_.error_on_malformatting && found_any_format_error) { |
310 | ctx->CtxFailure( |
311 | errors::InvalidArgument("Invalid formatting on input string" )); |
312 | } |
313 | } |
314 | |
315 | private: |
316 | // Consume a codepoint from the input string and add it to the buffer. |
317 | // This function takes care of any replacement configuration on invalid or |
318 | // out-of-range inputs. |
319 | void TranslateCodepoints(icu::UnicodeString* s, bool* found_any_format_error, |
320 | UChar32 ch, int src_bytes, bool format_error) { |
321 | if (ShouldHandleFormatError(error_options_, ch, format_error)) { |
322 | *found_any_format_error = true; |
323 | if (error_options_.elide_replacement) { |
324 | return; |
325 | } else { |
326 | ch = error_options_.subst; |
327 | } |
328 | } |
329 | s->append(ch); |
330 | } |
331 | |
332 | // Transcode the string from input encoding to the output_encoding_. If |
333 | // non-valid characters are encountered, use the subst_/elide_replacement_ |
334 | // config to handle them. |
335 | void Transcode(tstring* s, UConverter* input_encoder, |
336 | bool* found_any_format_error) { |
337 | icu::UnicodeString source; |
338 | IterateUnicodeString( |
339 | *s, input_encoder, |
340 | std::bind(&UnicodeTranscodeOp::TranslateCodepoints, this, &source, |
341 | found_any_format_error, std::placeholders::_1, |
342 | std::placeholders::_2, std::placeholders::_3)); |
343 | |
344 | Encode(output_encoding_, source, s); |
345 | } |
346 | |
347 | string input_encoding_; |
348 | ErrorOptions error_options_; |
349 | UnicodeEncoding output_encoding_ = UnicodeEncoding::UTF8; |
350 | }; |
351 | |
352 | REGISTER_KERNEL_BUILDER(Name("UnicodeTranscode" ).Device(DEVICE_CPU), |
353 | UnicodeTranscodeOp); |
354 | |
355 | template <typename SPLITS_TYPE> |
356 | class UnicodeDecodeBaseOp : public OpKernel { |
357 | public: |
358 | explicit UnicodeDecodeBaseOp(OpKernelConstruction* ctx, bool generate_offsets) |
359 | : OpKernel(ctx), generate_offsets_(generate_offsets) { |
360 | OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_)); |
361 | OP_REQUIRES_OK(ctx, ctx->GetAttr("input_encoding" , &input_encoding_)); |
362 | // Make a temporary UConverter to ensure it will create without error |
363 | // at execution time (and to warm any data caches the converter needs). |
364 | // This instance is not used. |
365 | std::unique_ptr<WrappedConverter> input_encoder = |
366 | std::make_unique<WrappedConverter>(); |
367 | input_encoder->init(input_encoding_); |
368 | OP_REQUIRES(ctx, input_encoder->converter_, |
369 | errors::InvalidArgument( |
370 | "Could not create converter for input encoding: " + |
371 | input_encoding_)); |
372 | } |
373 | |
374 | void Decode(OpKernelContext* ctx, std::vector<UChar32>* char_values, |
375 | std::vector<SPLITS_TYPE>* offset_values, int* current_offset, |
376 | SPLITS_TYPE* next_row_split, UChar32 char_value, int char_length, |
377 | bool found_any_format_error) { |
378 | if (error_options_.error_on_malformatting && found_any_format_error) { |
379 | ctx->CtxFailure( |
380 | errors::InvalidArgument("Invalid formatting on input string" )); |
381 | } |
382 | UChar32 decoded_value = char_value; |
383 | if (ShouldHandleFormatError(error_options_, char_value, |
384 | found_any_format_error)) { |
385 | if (error_options_.elide_replacement && (offset_values != nullptr)) { |
386 | *current_offset += char_length; |
387 | return; |
388 | } else { |
389 | decoded_value = error_options_.subst; |
390 | } |
391 | } |
392 | |
393 | // Emit the char value. |
394 | char_values->push_back(decoded_value); |
395 | |
396 | // Emit the byte offset |
397 | if (offset_values != nullptr) { |
398 | offset_values->push_back(*current_offset); |
399 | *current_offset += char_length; |
400 | } |
401 | *next_row_split += 1; |
402 | } |
403 | |
404 | void Compute(OpKernelContext* ctx) override { |
405 | const Tensor* input_tensor; |
406 | OP_REQUIRES_OK(ctx, ctx->input("input" , &input_tensor)); |
407 | |
408 | // Go through all the strings in `input`. |
409 | const auto& input_vec = input_tensor->flat<tstring>(); |
410 | |
411 | std::unique_ptr<WrappedConverter> input_encoder = |
412 | std::make_unique<WrappedConverter>(); |
413 | input_encoder->init(input_encoding_); |
414 | OP_REQUIRES(ctx, input_encoder->converter_, |
415 | errors::InvalidArgument( |
416 | "Could not create converter for input encoding: " + |
417 | input_encoding_)); |
418 | |
419 | std::vector<UChar32> char_values; |
420 | std::vector<SPLITS_TYPE> offset_values; |
421 | |
422 | Tensor* output_row_splits; |
423 | OP_REQUIRES_OK(ctx, ctx->allocate_output("row_splits" , |
424 | {input_tensor->NumElements() + 1}, |
425 | &output_row_splits)); |
426 | auto out_row_splits = output_row_splits->vec<SPLITS_TYPE>(); |
427 | |
428 | int row_split_index = 0; |
429 | SPLITS_TYPE next_row_split = 0; |
430 | for (int i = 0; i < input_vec.size(); ++i) { |
431 | const string& input = input_vec(i); |
432 | // Convert input strings into unicode values. Output to a list of |
433 | // char_values, record row splits and char_to_byte_starts, which are all |
434 | // the fields needed to construct a RaggedTensor. |
435 | out_row_splits(row_split_index) = next_row_split; |
436 | row_split_index++; |
437 | int current_offset = 0; |
438 | IterateUnicodeString( |
439 | input, input_encoder->converter_, |
440 | std::bind(&UnicodeDecodeBaseOp::Decode, this, ctx, &char_values, |
441 | &offset_values, ¤t_offset, &next_row_split, |
442 | std::placeholders::_1, std::placeholders::_2, |
443 | std::placeholders::_3)); |
444 | } |
445 | out_row_splits(row_split_index) = next_row_split; |
446 | |
447 | Tensor* output_char_values; |
448 | OP_REQUIRES_OK( |
449 | ctx, ctx->allocate_output( |
450 | "char_values" , {static_cast<SPLITS_TYPE>(char_values.size())}, |
451 | &output_char_values)); |
452 | auto out_char_values = output_char_values->vec<int32>(); |
453 | if (generate_offsets_) { |
454 | DCHECK(offset_values.size() == char_values.size()); |
455 | Tensor* output_offset_values; |
456 | OP_REQUIRES_OK(ctx, ctx->allocate_output( |
457 | "char_to_byte_starts" , |
458 | {static_cast<SPLITS_TYPE>(offset_values.size())}, |
459 | &output_offset_values)); |
460 | auto out_offset_values = output_offset_values->vec<SPLITS_TYPE>(); |
461 | |
462 | // Load output tensors from intermediate value arrays. |
463 | for (int i = 0; i < char_values.size(); ++i) { |
464 | out_char_values(i) = static_cast<int32>(char_values[i]); |
465 | out_offset_values(i) = offset_values[i]; |
466 | } |
467 | } else { |
468 | for (int i = 0; i < char_values.size(); ++i) { |
469 | out_char_values(i) = static_cast<int32>(char_values[i]); |
470 | } |
471 | } |
472 | } |
473 | |
474 | private: |
475 | string input_encoding_; |
476 | ErrorOptions error_options_; |
477 | bool generate_offsets_ = false; |
478 | }; |
479 | |
480 | template <typename SPLITS_TYPE> |
481 | class UnicodeDecodeOp : public UnicodeDecodeBaseOp<SPLITS_TYPE> { |
482 | public: |
483 | explicit UnicodeDecodeOp(OpKernelConstruction* ctx) |
484 | : UnicodeDecodeBaseOp<SPLITS_TYPE>(ctx, false) {} |
485 | }; |
486 | |
487 | template <typename SPLITS_TYPE> |
488 | class UnicodeDecodeWithOffsetsOp : public UnicodeDecodeBaseOp<SPLITS_TYPE> { |
489 | public: |
490 | explicit UnicodeDecodeWithOffsetsOp(OpKernelConstruction* ctx) |
491 | : UnicodeDecodeBaseOp<SPLITS_TYPE>(ctx, true) {} |
492 | }; |
493 | |
494 | REGISTER_KERNEL_BUILDER( |
495 | Name("UnicodeDecode" ).Device(DEVICE_CPU).TypeConstraint<int64_t>("Tsplits" ), |
496 | UnicodeDecodeOp<int64_t>); |
497 | REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets" ) |
498 | .Device(DEVICE_CPU) |
499 | .TypeConstraint<int64_t>("Tsplits" ), |
500 | UnicodeDecodeWithOffsetsOp<int64_t>); |
501 | REGISTER_KERNEL_BUILDER( |
502 | Name("UnicodeDecode" ).Device(DEVICE_CPU).TypeConstraint<int32>("Tsplits" ), |
503 | UnicodeDecodeOp<int32>); |
504 | REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets" ) |
505 | .Device(DEVICE_CPU) |
506 | .TypeConstraint<int32>("Tsplits" ), |
507 | UnicodeDecodeWithOffsetsOp<int32>); |
508 | |
509 | template <typename SPLITS_TYPE> |
510 | class UnicodeEncodeOp : public OpKernel { |
511 | public: |
512 | explicit UnicodeEncodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
513 | string encoding_tmp; |
514 | OP_REQUIRES_OK(ctx, ctx->GetAttr("output_encoding" , &encoding_tmp)); |
515 | OP_REQUIRES_OK(ctx, ParseUnicodeEncoding(encoding_tmp, &encoding_)); |
516 | OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_)); |
517 | } |
518 | |
519 | /** |
520 | * Encodes Unicode codepoints into the desired string representation. |
521 | * |
522 | * We lose a dimension while encoding, since a series of integer codepoints is |
523 | * encoded into a single string. |
524 | * |
525 | * This accepts two input tensors: a rank 1 tensor of code point values and |
526 | * a single rank 1 tensor of splits which determine where each string begins |
527 | * and ends from the provided code points. |
528 | */ |
529 | void Compute(OpKernelContext* context) override { |
530 | // Get inputs |
531 | const Tensor& input_tensor = context->input(0); |
532 | const auto input_tensor_flat = input_tensor.flat<int32>(); |
533 | const Tensor& input_splits = context->input(1); |
534 | const auto input_splits_flat = input_splits.flat<SPLITS_TYPE>(); |
535 | |
536 | OP_REQUIRES( |
537 | context, input_splits.NumElements() > 0, |
538 | errors::InvalidArgument("Input_splits should contain elements, but " |
539 | "given input_values has 0 elements" )); |
540 | // Operation will treat first argument in input_splits as if it were zero |
541 | // regardless of its actual value since splits should begin with zero and |
542 | // end with the length of the input values vector. |
543 | OP_REQUIRES( |
544 | context, input_splits_flat(0) == 0, |
545 | errors::InvalidArgument("First value in input_splits must be zero." )); |
546 | OP_REQUIRES(context, |
547 | input_splits_flat(input_splits_flat.size() - 1) == |
548 | input_tensor_flat.size(), |
549 | errors::InvalidArgument("Last value in input_splits must be " |
550 | "equal to length of input_tensor." )); |
551 | // Since we limit to a 2-D input (flat_values of rank 1 and a single splits |
552 | // tensor), our output dimension will be 1 with it's size equal to the |
553 | // number of splits (outer dimension or ragged tensor). |
554 | TensorShape output_shape({input_splits.dim_size(0) - 1}); |
555 | Tensor* output_tensor; |
556 | OP_REQUIRES_OK(context, context->allocate_output("output" , output_shape, |
557 | &output_tensor)); |
558 | auto output_tensor_flat = output_tensor->flat<tstring>(); |
559 | |
560 | // Use a single index over the flattened input values tensor. |
561 | int idx = 0; |
562 | // Loop through our split dimension to create a new string at each split. |
563 | for (int i = 1; i < input_splits_flat.size(); ++i) { |
564 | icu::UnicodeString unicode_string; |
565 | icu::UnicodeStringAppendable appendable_unicode_string(unicode_string); |
566 | OP_REQUIRES( |
567 | context, input_splits_flat(i - 1) <= input_splits_flat(i), |
568 | errors::InvalidArgument( |
569 | "Values in input_splits must be equal or in ascending order." )); |
570 | OP_REQUIRES( |
571 | context, input_splits_flat(i) <= input_tensor_flat.size(), |
572 | errors::InvalidArgument("Values in input_splits must be less than or " |
573 | "equal to input_tensor length." )); |
574 | for (; idx < input_splits_flat(i); ++idx) { |
575 | int32_t code_point = input_tensor_flat(idx); |
576 | // Check for invalid code point |
577 | if (!U_IS_UNICODE_CHAR(code_point)) { |
578 | if (error_options_.error_on_malformatting) { |
579 | context->CtxFailure(errors::InvalidArgument( |
580 | "Code point is out of range for Unicode, or a noncharacter." )); |
581 | return; |
582 | } else if (!error_options_.elide_replacement) { |
583 | code_point = error_options_.subst; |
584 | } |
585 | } |
586 | appendable_unicode_string.appendCodePoint(code_point); |
587 | } |
588 | // Encode our string and save in the output. |
589 | tstring result; |
590 | Encode(encoding_, unicode_string, &result); |
591 | output_tensor_flat(i - 1) = std::move(result); |
592 | } |
593 | } |
594 | |
595 | private: |
596 | UnicodeEncoding encoding_; |
597 | ErrorOptions error_options_; |
598 | }; |
599 | |
600 | REGISTER_KERNEL_BUILDER( |
601 | Name("UnicodeEncode" ).Device(DEVICE_CPU).TypeConstraint<int64_t>("Tsplits" ), |
602 | UnicodeEncodeOp<int64_t>); |
603 | REGISTER_KERNEL_BUILDER( |
604 | Name("UnicodeEncode" ).Device(DEVICE_CPU).TypeConstraint<int32>("Tsplits" ), |
605 | UnicodeEncodeOp<int32>); |
606 | |
607 | } // namespace tensorflow |
608 | |