1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | #include <vector> |
18 | |
19 | #include "absl/strings/str_split.h" |
20 | #include "tensorflow/core/framework/common_shape_fns.h" |
21 | #include "tensorflow/core/framework/op.h" |
22 | #include "tensorflow/core/framework/shape_inference.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | #include "tensorflow/core/lib/core/status.h" |
25 | #include "tensorflow/core/lib/strings/strcat.h" |
26 | #include "tensorflow/core/platform/types.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | namespace shape_inference { |
31 | class InferenceContext; |
32 | } // namespace shape_inference |
33 | |
34 | using shape_inference::DimensionHandle; |
35 | using shape_inference::InferenceContext; |
36 | using shape_inference::ShapeHandle; |
37 | |
38 | REGISTER_OP("RegexReplace" ) |
39 | .Input("input: string" ) |
40 | .Input("pattern: string" ) |
41 | .Input("rewrite: string" ) |
42 | .Output("output: string" ) |
43 | .Attr("replace_global: bool = true" ) |
44 | .SetShapeFn([](InferenceContext* c) { |
45 | ShapeHandle unused; |
46 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
47 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
48 | c->set_output(0, c->input(0)); |
49 | return OkStatus(); |
50 | }); |
51 | |
52 | REGISTER_OP("StaticRegexReplace" ) |
53 | .Input("input: string" ) |
54 | .Attr("pattern: string" ) |
55 | .Attr("rewrite: string" ) |
56 | .Output("output: string" ) |
57 | .Attr("replace_global: bool = true" ) |
58 | .SetShapeFn(shape_inference::UnchangedShape); |
59 | |
60 | REGISTER_OP("RegexFullMatch" ) |
61 | .Input("input: string" ) |
62 | .Input("pattern: string" ) |
63 | .Output("output: bool" ) |
64 | .SetShapeFn([](InferenceContext* c) { |
65 | ShapeHandle unused; |
66 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
67 | c->set_output(0, c->input(0)); |
68 | return OkStatus(); |
69 | }); |
70 | |
71 | REGISTER_OP("StaticRegexFullMatch" ) |
72 | .Input("input: string" ) |
73 | .Attr("pattern: string" ) |
74 | .Output("output: bool" ) |
75 | .SetShapeFn(shape_inference::UnchangedShape); |
76 | |
77 | REGISTER_OP("StringToHashBucketFast" ) |
78 | .Input("input: string" ) |
79 | .Output("output: int64" ) |
80 | .Attr("num_buckets: int >= 1" ) |
81 | .SetShapeFn(shape_inference::UnchangedShape); |
82 | |
83 | REGISTER_OP("_TensorToHashBucketFast" ) |
84 | .Input("input: T" ) |
85 | .Output("output: int64" ) |
86 | .Attr("T: {int8, uint8, int16, uint16, int32, uint32, int64, uint64}" ) |
87 | .Attr("num_buckets: int >= 1" ) |
88 | .SetShapeFn(shape_inference::UnchangedShape) |
89 | .Doc(R"doc( |
90 | Internal operation which is a composition of converting the tensor to a string |
91 | tensor (AsString) and then calling hash functions (StringToHashBucketFast): |
92 | reserved for internal use. |
93 | |
94 | Do not invoke this operator directly in Python. A fusion optimization is |
95 | expected to create these operators. |
96 | )doc" ); |
97 | |
98 | REGISTER_OP("StringToHashBucketStrong" ) |
99 | .Input("input: string" ) |
100 | .Output("output: int64" ) |
101 | .Attr("num_buckets: int >= 1" ) |
102 | .Attr("key: list(int)" ) |
103 | .SetShapeFn(shape_inference::UnchangedShape); |
104 | |
105 | REGISTER_OP("StringToHashBucket" ) |
106 | .Input("string_tensor: string" ) |
107 | .Output("output: int64" ) |
108 | .Attr("num_buckets: int >= 1" ) |
109 | .SetShapeFn(shape_inference::UnchangedShape); |
110 | |
111 | REGISTER_OP("ReduceJoin" ) |
112 | .Input("inputs: string" ) |
113 | .Input("reduction_indices: int32" ) |
114 | .Attr("keep_dims: bool = false" ) |
115 | .Attr("separator: string = ''" ) |
116 | .Output("output: string" ) |
117 | .SetShapeFn(shape_inference::ReductionShape); |
118 | |
119 | REGISTER_OP("UnsortedSegmentJoin" ) |
120 | .Input("inputs: string" ) |
121 | .Input("segment_ids: Tindices" ) |
122 | .Input("num_segments: Tnumsegments" ) |
123 | .Attr("separator: string = ''" ) |
124 | .Attr("Tindices: {int32,int64}" ) |
125 | .Attr("Tnumsegments: {int32,int64} = DT_INT32" ) |
126 | .Output("output: string" ) |
127 | .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn); |
128 | |
129 | REGISTER_OP("AsString" ) |
130 | .Input("input: T" ) |
131 | .Output("output: string" ) |
132 | .Attr("T: {realnumbertype, complex64, complex128, bool, variant}" ) |
133 | .Attr("precision: int = -1" ) |
134 | .Attr("scientific: bool = false" ) |
135 | .Attr("shortest: bool = false" ) |
136 | .Attr("width: int = -1" ) |
137 | .Attr("fill: string = ''" ) |
138 | .SetShapeFn(shape_inference::UnchangedShape); |
139 | |
140 | REGISTER_OP("StringFormat" ) |
141 | .Input("inputs: T" ) |
142 | .Output("output: string" ) |
143 | .Attr("T: list(type) >= 0" ) |
144 | .Attr("template: string = '%s'" ) |
145 | .Attr("placeholder: string = '%s'" ) |
146 | .Attr("summarize: int = 3" ) |
147 | .SetShapeFn([](InferenceContext* c) { |
148 | string template_; |
149 | string placeholder; |
150 | TF_RETURN_IF_ERROR(c->GetAttr("template" , &template_)); |
151 | TF_RETURN_IF_ERROR(c->GetAttr("placeholder" , &placeholder)); |
152 | |
153 | std::vector<std::string> split_template; |
154 | split_template = absl::StrSplit(template_, placeholder); |
155 | int64_t num_placeholders = split_template.size() - 1; |
156 | if (c->num_inputs() != num_placeholders) { |
157 | return errors::InvalidArgument(strings::StrCat( |
158 | "num placeholders in template and num inputs must match: " , |
159 | num_placeholders, " vs. " , c->num_inputs())); |
160 | } |
161 | |
162 | c->set_output(0, c->Scalar()); |
163 | return OkStatus(); |
164 | }); |
165 | |
166 | REGISTER_OP("StringJoin" ) |
167 | .Input("inputs: N * string" ) |
168 | .Attr("N: int" ) |
169 | .Attr("separator: string = ''" ) |
170 | .Output("output: string" ) |
171 | .SetShapeFn([](InferenceContext* c) { |
172 | // If all inputs are scalars, then return a scalar. |
173 | bool all_scalar = true; |
174 | for (int i = 0; i < c->num_inputs(); ++i) { |
175 | if (c->Rank(c->input(i)) != 0) all_scalar = false; |
176 | } |
177 | if (all_scalar) { |
178 | c->set_output(0, c->Scalar()); |
179 | return OkStatus(); |
180 | } |
181 | |
182 | // At least one input is unknown or a scalar. |
183 | // Merge the non-scalars to find the output shape. |
184 | // Don't merge inputs with unknown rank, as they can actually be scalars |
185 | // or the output shape. |
186 | ShapeHandle out = c->UnknownShape(); |
187 | for (int i = 0; i < c->num_inputs(); ++i) { |
188 | if (c->RankKnown(c->input(i)) && c->Rank(c->input(i)) != 0) { |
189 | TF_RETURN_IF_ERROR(c->Merge(out, c->input(i), &out)); |
190 | } |
191 | } |
192 | c->set_output(0, out); |
193 | return OkStatus(); |
194 | }); |
195 | |
196 | REGISTER_OP("StringSplit" ) |
197 | .Input("input: string" ) |
198 | .Input("delimiter: string" ) |
199 | .Output("indices: int64" ) |
200 | .Output("values: string" ) |
201 | .Output("shape: int64" ) |
202 | .Attr("skip_empty: bool = true" ) |
203 | .SetShapeFn([](InferenceContext* c) { |
204 | ShapeHandle unused; |
205 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
206 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
207 | |
208 | c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2)); |
209 | c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
210 | c->set_output(2, c->Vector(2)); |
211 | return OkStatus(); |
212 | }); |
213 | |
214 | REGISTER_OP("StringSplitV2" ) |
215 | .Input("input: string" ) |
216 | .Input("sep: string" ) |
217 | .Output("indices: int64" ) |
218 | .Output("values: string" ) |
219 | .Output("shape: int64" ) |
220 | .Attr("maxsplit: int = -1" ) |
221 | .SetShapeFn([](InferenceContext* c) { |
222 | ShapeHandle unused; |
223 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
224 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
225 | |
226 | c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2)); |
227 | c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
228 | c->set_output(2, c->Vector(2)); |
229 | return OkStatus(); |
230 | }); |
231 | |
232 | REGISTER_OP("StringLower" ) |
233 | .Input("input: string" ) |
234 | .Output("output: string" ) |
235 | .Attr("encoding: string =''" ) |
236 | .SetShapeFn(shape_inference::UnchangedShape); |
237 | |
238 | REGISTER_OP("StringUpper" ) |
239 | .Input("input: string" ) |
240 | .Output("output: string" ) |
241 | .Attr("encoding: string =''" ) |
242 | .SetShapeFn(shape_inference::UnchangedShape); |
243 | |
244 | REGISTER_OP("StringStrip" ) |
245 | .Input("input: string" ) |
246 | .Output("output: string" ) |
247 | .SetShapeFn(shape_inference::UnchangedShape); |
248 | |
249 | REGISTER_OP("StringLength" ) |
250 | .Input("input: string" ) |
251 | .Output("output: int32" ) |
252 | .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'" ) |
253 | .SetShapeFn(shape_inference::UnchangedShape); |
254 | |
255 | REGISTER_OP("EncodeBase64" ) |
256 | .Input("input: string" ) |
257 | .Output("output: string" ) |
258 | .Attr("pad: bool = false" ) |
259 | .SetShapeFn(shape_inference::UnchangedShape); |
260 | |
261 | REGISTER_OP("DecodeBase64" ) |
262 | .Input("input: string" ) |
263 | .Output("output: string" ) |
264 | .SetShapeFn(shape_inference::UnchangedShape); |
265 | |
266 | REGISTER_OP("Substr" ) |
267 | .Input("input: string" ) |
268 | .Input("pos: T" ) |
269 | .Input("len: T" ) |
270 | .Output("output: string" ) |
271 | .Attr("T: {int32, int64}" ) |
272 | .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'" ) |
273 | .SetShapeFn([](InferenceContext* c) { |
274 | ShapeHandle pos_shape = c->input(1); |
275 | ShapeHandle len_shape = c->input(2); |
276 | ShapeHandle unused; |
277 | // If len rank is known, check that pos and len have the same rank |
278 | if (c->RankKnown(len_shape)) { |
279 | TF_RETURN_IF_ERROR(c->WithRank(pos_shape, c->Rank(len_shape), &unused)); |
280 | } |
281 | // Check that dimensions are equal |
282 | for (int32_t i = 0; i < c->Rank(pos_shape); ++i) { |
283 | DimensionHandle pos_dim = c->Dim(pos_shape, i); |
284 | DimensionHandle len_dim = c->Dim(len_shape, i); |
285 | if (c->Value(pos_dim) != c->Value(len_dim)) { |
286 | return errors::InvalidArgument( |
287 | "pos and len shapes must match: " , c->DebugString(pos_shape), |
288 | " vs. " , c->DebugString(len_shape)); |
289 | } |
290 | } |
291 | // c->input(0) is the ShapeHandle to input strings |
292 | // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1). |
293 | return shape_inference::BroadcastBinaryOpShapeFn(c); |
294 | }); |
295 | |
296 | REGISTER_OP("UnicodeScript" ) |
297 | .Input("input: int32" ) |
298 | .Output("output: int32" ) |
299 | .SetShapeFn(shape_inference::UnchangedShape); |
300 | |
301 | REGISTER_OP("UnicodeEncode" ) |
302 | .Input("input_values: int32" ) |
303 | .Input("input_splits: Tsplits" ) |
304 | .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'" ) |
305 | .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}" ) |
306 | .Attr("replacement_char: int = 65533" ) // 0xFFFD unicode replacement char |
307 | .Attr("Tsplits: {int32, int64} = DT_INT64" ) |
308 | .Output("output: string" ) |
309 | .SetShapeFn([](InferenceContext* c) { |
310 | // Check rank of inner values |
311 | ShapeHandle input_inner_values_shape = c->input(0); |
312 | ShapeHandle unused; |
313 | TF_RETURN_IF_ERROR(c->WithRank(input_inner_values_shape, 1, &unused)); |
314 | |
315 | // Check rank of input_splits |
316 | ShapeHandle splits_shape = c->input(1); |
317 | TF_RETURN_IF_ERROR(c->WithRank(splits_shape, 1, &unused)); |
318 | |
319 | // Output shape is a 1-D tensor with size equal to number of splits. |
320 | std::vector<DimensionHandle> dims(1); |
321 | TF_RETURN_IF_ERROR(c->Subtract(c->Dim(splits_shape, 0), 1, &dims[0])); |
322 | c->set_output(0, c->MakeShape(dims)); |
323 | |
324 | return OkStatus(); |
325 | }); |
326 | |
327 | REGISTER_OP("UnicodeTranscode" ) |
328 | .Input("input: string" ) |
329 | .Output("output: string" ) |
330 | .Attr("input_encoding: string" ) |
331 | .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}" ) |
332 | .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'" ) |
333 | .Attr("replacement_char: int = 65533" ) // 0xFFFD unicode replacement char |
334 | .Attr("replace_control_characters: bool = false" ) |
335 | .SetShapeFn(shape_inference::UnchangedShape); |
336 | |
337 | REGISTER_OP("UnicodeDecode" ) |
338 | .Input("input: string" ) |
339 | .Output("row_splits: Tsplits" ) |
340 | .Output("char_values: int32" ) |
341 | .Attr("input_encoding: string" ) |
342 | .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'" ) |
343 | .Attr("replacement_char: int = 65533" ) // 0xFFFD unicode replacement char |
344 | .Attr("replace_control_characters: bool = false" ) |
345 | .Attr("Tsplits: {int32, int64} = DT_INT64" ) |
346 | .SetShapeFn([](InferenceContext* c) { |
347 | // row_splits.shape == [input.size() + 1] |
348 | DimensionHandle num_row_splits; |
349 | DimensionHandle input_size = c->NumElements(c->input(0)); |
350 | TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits)); |
351 | c->set_output(0, c->Vector(num_row_splits)); |
352 | |
353 | // char_values.shape == [num_chars] |
354 | DimensionHandle num_chars = c->UnknownDim(); |
355 | c->set_output(1, c->Vector(num_chars)); |
356 | return OkStatus(); |
357 | }); |
358 | |
359 | REGISTER_OP("UnicodeDecodeWithOffsets" ) |
360 | .Input("input: string" ) |
361 | .Output("row_splits: Tsplits" ) |
362 | .Output("char_values: int32" ) |
363 | .Output("char_to_byte_starts: int64" ) |
364 | .Attr("input_encoding: string" ) |
365 | .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'" ) |
366 | .Attr("replacement_char: int = 65533" ) // 0xFFFD unicode replacement char |
367 | .Attr("replace_control_characters: bool = false" ) |
368 | .Attr("Tsplits: {int32, int64} = DT_INT64" ) |
369 | .SetShapeFn([](InferenceContext* c) { |
370 | // row_splits.shape == [input.size() + 1] |
371 | DimensionHandle num_row_splits; |
372 | DimensionHandle input_size = c->NumElements(c->input(0)); |
373 | TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits)); |
374 | c->set_output(0, c->Vector(num_row_splits)); |
375 | |
376 | // char_values.shape == offset_values.shape == [num_chars] |
377 | DimensionHandle num_chars = c->UnknownDim(); |
378 | c->set_output(1, c->Vector(num_chars)); |
379 | c->set_output(2, c->Vector(num_chars)); |
380 | return OkStatus(); |
381 | }); |
382 | |
383 | REGISTER_OP("StringNGrams" ) |
384 | .Attr("separator: string" ) |
385 | .Attr("ngram_widths: list(int) >= 0" ) |
386 | .Attr("left_pad: string" ) |
387 | .Attr("right_pad: string" ) |
388 | .Attr("pad_width: int" ) |
389 | .Attr("preserve_short_sequences: bool" ) |
390 | .Attr("Tsplits: {int32, int64} = DT_INT64" ) |
391 | .Input("data: string" ) |
392 | .Input("data_splits: Tsplits" ) |
393 | .Output("ngrams: string" ) |
394 | .Output("ngrams_splits: Tsplits" ) |
395 | .SetShapeFn([](InferenceContext* c) { |
396 | c->set_output(0, c->UnknownShapeOfRank(1)); |
397 | ShapeHandle data = c->input(0); |
398 | TF_RETURN_IF_ERROR(c->WithRank(data, 1, &data)); |
399 | ShapeHandle data_splits = c->input(1); |
400 | TF_RETURN_IF_ERROR(c->WithRank(data_splits, 1, &data_splits)); |
401 | c->set_output(1, data_splits); |
402 | return OkStatus(); |
403 | }); |
404 | |
405 | } // namespace tensorflow |
406 | |