1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <string>
17#include <vector>
18
19#include "absl/strings/str_split.h"
20#include "tensorflow/core/framework/common_shape_fns.h"
21#include "tensorflow/core/framework/op.h"
22#include "tensorflow/core/framework/shape_inference.h"
23#include "tensorflow/core/lib/core/errors.h"
24#include "tensorflow/core/lib/core/status.h"
25#include "tensorflow/core/lib/strings/strcat.h"
26#include "tensorflow/core/platform/types.h"
27
28namespace tensorflow {
29
30namespace shape_inference {
31class InferenceContext;
32} // namespace shape_inference
33
34using shape_inference::DimensionHandle;
35using shape_inference::InferenceContext;
36using shape_inference::ShapeHandle;
37
38REGISTER_OP("RegexReplace")
39 .Input("input: string")
40 .Input("pattern: string")
41 .Input("rewrite: string")
42 .Output("output: string")
43 .Attr("replace_global: bool = true")
44 .SetShapeFn([](InferenceContext* c) {
45 ShapeHandle unused;
46 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
47 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
48 c->set_output(0, c->input(0));
49 return OkStatus();
50 });
51
52REGISTER_OP("StaticRegexReplace")
53 .Input("input: string")
54 .Attr("pattern: string")
55 .Attr("rewrite: string")
56 .Output("output: string")
57 .Attr("replace_global: bool = true")
58 .SetShapeFn(shape_inference::UnchangedShape);
59
60REGISTER_OP("RegexFullMatch")
61 .Input("input: string")
62 .Input("pattern: string")
63 .Output("output: bool")
64 .SetShapeFn([](InferenceContext* c) {
65 ShapeHandle unused;
66 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
67 c->set_output(0, c->input(0));
68 return OkStatus();
69 });
70
71REGISTER_OP("StaticRegexFullMatch")
72 .Input("input: string")
73 .Attr("pattern: string")
74 .Output("output: bool")
75 .SetShapeFn(shape_inference::UnchangedShape);
76
77REGISTER_OP("StringToHashBucketFast")
78 .Input("input: string")
79 .Output("output: int64")
80 .Attr("num_buckets: int >= 1")
81 .SetShapeFn(shape_inference::UnchangedShape);
82
83REGISTER_OP("_TensorToHashBucketFast")
84 .Input("input: T")
85 .Output("output: int64")
86 .Attr("T: {int8, uint8, int16, uint16, int32, uint32, int64, uint64}")
87 .Attr("num_buckets: int >= 1")
88 .SetShapeFn(shape_inference::UnchangedShape)
89 .Doc(R"doc(
90Internal operation which is a composition of converting the tensor to a string
91tensor (AsString) and then calling hash functions (StringToHashBucketFast):
92reserved for internal use.
93
94Do not invoke this operator directly in Python. A fusion optimization is
95expected to create these operators.
96)doc");
97
98REGISTER_OP("StringToHashBucketStrong")
99 .Input("input: string")
100 .Output("output: int64")
101 .Attr("num_buckets: int >= 1")
102 .Attr("key: list(int)")
103 .SetShapeFn(shape_inference::UnchangedShape);
104
105REGISTER_OP("StringToHashBucket")
106 .Input("string_tensor: string")
107 .Output("output: int64")
108 .Attr("num_buckets: int >= 1")
109 .SetShapeFn(shape_inference::UnchangedShape);
110
111REGISTER_OP("ReduceJoin")
112 .Input("inputs: string")
113 .Input("reduction_indices: int32")
114 .Attr("keep_dims: bool = false")
115 .Attr("separator: string = ''")
116 .Output("output: string")
117 .SetShapeFn(shape_inference::ReductionShape);
118
119REGISTER_OP("UnsortedSegmentJoin")
120 .Input("inputs: string")
121 .Input("segment_ids: Tindices")
122 .Input("num_segments: Tnumsegments")
123 .Attr("separator: string = ''")
124 .Attr("Tindices: {int32,int64}")
125 .Attr("Tnumsegments: {int32,int64} = DT_INT32")
126 .Output("output: string")
127 .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
128
129REGISTER_OP("AsString")
130 .Input("input: T")
131 .Output("output: string")
132 .Attr("T: {realnumbertype, complex64, complex128, bool, variant}")
133 .Attr("precision: int = -1")
134 .Attr("scientific: bool = false")
135 .Attr("shortest: bool = false")
136 .Attr("width: int = -1")
137 .Attr("fill: string = ''")
138 .SetShapeFn(shape_inference::UnchangedShape);
139
140REGISTER_OP("StringFormat")
141 .Input("inputs: T")
142 .Output("output: string")
143 .Attr("T: list(type) >= 0")
144 .Attr("template: string = '%s'")
145 .Attr("placeholder: string = '%s'")
146 .Attr("summarize: int = 3")
147 .SetShapeFn([](InferenceContext* c) {
148 string template_;
149 string placeholder;
150 TF_RETURN_IF_ERROR(c->GetAttr("template", &template_));
151 TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder));
152
153 std::vector<std::string> split_template;
154 split_template = absl::StrSplit(template_, placeholder);
155 int64_t num_placeholders = split_template.size() - 1;
156 if (c->num_inputs() != num_placeholders) {
157 return errors::InvalidArgument(strings::StrCat(
158 "num placeholders in template and num inputs must match: ",
159 num_placeholders, " vs. ", c->num_inputs()));
160 }
161
162 c->set_output(0, c->Scalar());
163 return OkStatus();
164 });
165
166REGISTER_OP("StringJoin")
167 .Input("inputs: N * string")
168 .Attr("N: int")
169 .Attr("separator: string = ''")
170 .Output("output: string")
171 .SetShapeFn([](InferenceContext* c) {
172 // If all inputs are scalars, then return a scalar.
173 bool all_scalar = true;
174 for (int i = 0; i < c->num_inputs(); ++i) {
175 if (c->Rank(c->input(i)) != 0) all_scalar = false;
176 }
177 if (all_scalar) {
178 c->set_output(0, c->Scalar());
179 return OkStatus();
180 }
181
182 // At least one input is unknown or a scalar.
183 // Merge the non-scalars to find the output shape.
184 // Don't merge inputs with unknown rank, as they can actually be scalars
185 // or the output shape.
186 ShapeHandle out = c->UnknownShape();
187 for (int i = 0; i < c->num_inputs(); ++i) {
188 if (c->RankKnown(c->input(i)) && c->Rank(c->input(i)) != 0) {
189 TF_RETURN_IF_ERROR(c->Merge(out, c->input(i), &out));
190 }
191 }
192 c->set_output(0, out);
193 return OkStatus();
194 });
195
196REGISTER_OP("StringSplit")
197 .Input("input: string")
198 .Input("delimiter: string")
199 .Output("indices: int64")
200 .Output("values: string")
201 .Output("shape: int64")
202 .Attr("skip_empty: bool = true")
203 .SetShapeFn([](InferenceContext* c) {
204 ShapeHandle unused;
205 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
206 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
207
208 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
209 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
210 c->set_output(2, c->Vector(2));
211 return OkStatus();
212 });
213
214REGISTER_OP("StringSplitV2")
215 .Input("input: string")
216 .Input("sep: string")
217 .Output("indices: int64")
218 .Output("values: string")
219 .Output("shape: int64")
220 .Attr("maxsplit: int = -1")
221 .SetShapeFn([](InferenceContext* c) {
222 ShapeHandle unused;
223 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
224 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
225
226 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
227 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
228 c->set_output(2, c->Vector(2));
229 return OkStatus();
230 });
231
232REGISTER_OP("StringLower")
233 .Input("input: string")
234 .Output("output: string")
235 .Attr("encoding: string =''")
236 .SetShapeFn(shape_inference::UnchangedShape);
237
238REGISTER_OP("StringUpper")
239 .Input("input: string")
240 .Output("output: string")
241 .Attr("encoding: string =''")
242 .SetShapeFn(shape_inference::UnchangedShape);
243
244REGISTER_OP("StringStrip")
245 .Input("input: string")
246 .Output("output: string")
247 .SetShapeFn(shape_inference::UnchangedShape);
248
249REGISTER_OP("StringLength")
250 .Input("input: string")
251 .Output("output: int32")
252 .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
253 .SetShapeFn(shape_inference::UnchangedShape);
254
255REGISTER_OP("EncodeBase64")
256 .Input("input: string")
257 .Output("output: string")
258 .Attr("pad: bool = false")
259 .SetShapeFn(shape_inference::UnchangedShape);
260
261REGISTER_OP("DecodeBase64")
262 .Input("input: string")
263 .Output("output: string")
264 .SetShapeFn(shape_inference::UnchangedShape);
265
266REGISTER_OP("Substr")
267 .Input("input: string")
268 .Input("pos: T")
269 .Input("len: T")
270 .Output("output: string")
271 .Attr("T: {int32, int64}")
272 .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
273 .SetShapeFn([](InferenceContext* c) {
274 ShapeHandle pos_shape = c->input(1);
275 ShapeHandle len_shape = c->input(2);
276 ShapeHandle unused;
277 // If len rank is known, check that pos and len have the same rank
278 if (c->RankKnown(len_shape)) {
279 TF_RETURN_IF_ERROR(c->WithRank(pos_shape, c->Rank(len_shape), &unused));
280 }
281 // Check that dimensions are equal
282 for (int32_t i = 0; i < c->Rank(pos_shape); ++i) {
283 DimensionHandle pos_dim = c->Dim(pos_shape, i);
284 DimensionHandle len_dim = c->Dim(len_shape, i);
285 if (c->Value(pos_dim) != c->Value(len_dim)) {
286 return errors::InvalidArgument(
287 "pos and len shapes must match: ", c->DebugString(pos_shape),
288 " vs. ", c->DebugString(len_shape));
289 }
290 }
291 // c->input(0) is the ShapeHandle to input strings
292 // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1).
293 return shape_inference::BroadcastBinaryOpShapeFn(c);
294 });
295
296REGISTER_OP("UnicodeScript")
297 .Input("input: int32")
298 .Output("output: int32")
299 .SetShapeFn(shape_inference::UnchangedShape);
300
301REGISTER_OP("UnicodeEncode")
302 .Input("input_values: int32")
303 .Input("input_splits: Tsplits")
304 .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'")
305 .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}")
306 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char
307 .Attr("Tsplits: {int32, int64} = DT_INT64")
308 .Output("output: string")
309 .SetShapeFn([](InferenceContext* c) {
310 // Check rank of inner values
311 ShapeHandle input_inner_values_shape = c->input(0);
312 ShapeHandle unused;
313 TF_RETURN_IF_ERROR(c->WithRank(input_inner_values_shape, 1, &unused));
314
315 // Check rank of input_splits
316 ShapeHandle splits_shape = c->input(1);
317 TF_RETURN_IF_ERROR(c->WithRank(splits_shape, 1, &unused));
318
319 // Output shape is a 1-D tensor with size equal to number of splits.
320 std::vector<DimensionHandle> dims(1);
321 TF_RETURN_IF_ERROR(c->Subtract(c->Dim(splits_shape, 0), 1, &dims[0]));
322 c->set_output(0, c->MakeShape(dims));
323
324 return OkStatus();
325 });
326
327REGISTER_OP("UnicodeTranscode")
328 .Input("input: string")
329 .Output("output: string")
330 .Attr("input_encoding: string")
331 .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}")
332 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
333 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char
334 .Attr("replace_control_characters: bool = false")
335 .SetShapeFn(shape_inference::UnchangedShape);
336
337REGISTER_OP("UnicodeDecode")
338 .Input("input: string")
339 .Output("row_splits: Tsplits")
340 .Output("char_values: int32")
341 .Attr("input_encoding: string")
342 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
343 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char
344 .Attr("replace_control_characters: bool = false")
345 .Attr("Tsplits: {int32, int64} = DT_INT64")
346 .SetShapeFn([](InferenceContext* c) {
347 // row_splits.shape == [input.size() + 1]
348 DimensionHandle num_row_splits;
349 DimensionHandle input_size = c->NumElements(c->input(0));
350 TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits));
351 c->set_output(0, c->Vector(num_row_splits));
352
353 // char_values.shape == [num_chars]
354 DimensionHandle num_chars = c->UnknownDim();
355 c->set_output(1, c->Vector(num_chars));
356 return OkStatus();
357 });
358
359REGISTER_OP("UnicodeDecodeWithOffsets")
360 .Input("input: string")
361 .Output("row_splits: Tsplits")
362 .Output("char_values: int32")
363 .Output("char_to_byte_starts: int64")
364 .Attr("input_encoding: string")
365 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
366 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char
367 .Attr("replace_control_characters: bool = false")
368 .Attr("Tsplits: {int32, int64} = DT_INT64")
369 .SetShapeFn([](InferenceContext* c) {
370 // row_splits.shape == [input.size() + 1]
371 DimensionHandle num_row_splits;
372 DimensionHandle input_size = c->NumElements(c->input(0));
373 TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits));
374 c->set_output(0, c->Vector(num_row_splits));
375
376 // char_values.shape == offset_values.shape == [num_chars]
377 DimensionHandle num_chars = c->UnknownDim();
378 c->set_output(1, c->Vector(num_chars));
379 c->set_output(2, c->Vector(num_chars));
380 return OkStatus();
381 });
382
383REGISTER_OP("StringNGrams")
384 .Attr("separator: string")
385 .Attr("ngram_widths: list(int) >= 0")
386 .Attr("left_pad: string")
387 .Attr("right_pad: string")
388 .Attr("pad_width: int")
389 .Attr("preserve_short_sequences: bool")
390 .Attr("Tsplits: {int32, int64} = DT_INT64")
391 .Input("data: string")
392 .Input("data_splits: Tsplits")
393 .Output("ngrams: string")
394 .Output("ngrams_splits: Tsplits")
395 .SetShapeFn([](InferenceContext* c) {
396 c->set_output(0, c->UnknownShapeOfRank(1));
397 ShapeHandle data = c->input(0);
398 TF_RETURN_IF_ERROR(c->WithRank(data, 1, &data));
399 ShapeHandle data_splits = c->input(1);
400 TF_RETURN_IF_ERROR(c->WithRank(data_splits, 1, &data_splits));
401 c->set_output(1, data_splits);
402 return OkStatus();
403 });
404
405} // namespace tensorflow
406