1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | #include "tensorflow/core/lib/core/errors.h" |
20 | #include "tensorflow/core/util/example_proto_helper.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | using shape_inference::DimensionHandle; |
25 | using shape_inference::InferenceContext; |
26 | using shape_inference::ShapeHandle; |
27 | |
28 | namespace { |
29 | |
30 | // Adds output shapes for dense tensors in Parse*Example ops. |
31 | template <typename TensorShapeType> // TensorShape or PartialTensorShape |
32 | Status AddDenseOutputShapes(const std::vector<TensorShapeType>& dense_shapes, |
33 | const ShapeHandle& prefix, InferenceContext* c, |
34 | int* output_idx) { |
35 | for (const auto& dense_shape : dense_shapes) { |
36 | ShapeHandle s; |
37 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(dense_shape, &s)); |
38 | TF_RETURN_IF_ERROR(c->Concatenate(prefix, s, &s)); |
39 | c->set_output((*output_idx)++, s); |
40 | } |
41 | return OkStatus(); |
42 | } |
43 | |
44 | // Adds output shapes for sparse tensors in Parse*Example ops. |
45 | void AddSparseOutputShapes(int num_sparse, const ShapeHandle input_shape, |
46 | int64_t rank_delta, InferenceContext* c, |
47 | int* output_idx) { |
48 | // Rank of SparseTensor is rank of input tensor plus rank_delta. |
49 | shape_inference::DimensionOrConstant rank(c->UnknownDim()); |
50 | if (c->RankKnown(input_shape)) { |
51 | rank = c->Rank(input_shape) + rank_delta; |
52 | } |
53 | for (int i = 0; i < num_sparse; ++i) { // sparse_indices |
54 | c->set_output((*output_idx)++, c->Matrix(c->UnknownDim(), rank)); |
55 | } |
56 | for (int i = 0; i < num_sparse; ++i) { // sparse_values |
57 | c->set_output((*output_idx)++, c->Vector(c->UnknownDim())); |
58 | } |
59 | for (int i = 0; i < num_sparse; ++i) { // sparse_dense_shapes |
60 | c->set_output((*output_idx)++, c->Vector(rank)); |
61 | } |
62 | } |
63 | |
64 | // Adds output shapes for ragged tensors in Parse*Example ops. |
65 | Status AddRaggedOutputShapes(int num_ragged, bool ragged_rank_2, |
66 | const DimensionHandle& num_examples, |
67 | InferenceContext* c, int* output_idx) { |
68 | DimensionHandle num_splits; |
69 | TF_RETURN_IF_ERROR(c->Add(num_examples, 1, &num_splits)); |
70 | // Values |
71 | for (int i = 0; i < num_ragged; ++i) { |
72 | c->set_output((*output_idx)++, c->Vector(c->UnknownDim())); |
73 | } |
74 | // Outer row_splits. |
75 | for (int i = 0; i < num_ragged; ++i) { |
76 | c->set_output((*output_idx)++, c->Vector(num_splits)); |
77 | } |
78 | // Inner row_splits (for ParseSequenceExample feature_list features) |
79 | if (ragged_rank_2) { |
80 | for (int i = 0; i < num_ragged; ++i) { |
81 | c->set_output((*output_idx)++, c->Vector(c->UnknownDim())); |
82 | } |
83 | } |
84 | return OkStatus(); |
85 | } |
86 | |
87 | // Adds output shapes for dense_lengths tensors in Parse*Example ops. |
88 | void AddDenseLengthsShapes(int num_dense, const ShapeHandle& shape, |
89 | InferenceContext* c, int* output_idx) { |
90 | for (int i = 0; i < num_dense; ++i) { |
91 | c->set_output((*output_idx)++, shape); |
92 | } |
93 | } |
94 | |
95 | } // namespace |
96 | |
97 | REGISTER_OP("DecodeRaw" ) |
98 | .Input("bytes: string" ) |
99 | .Output("output: out_type" ) |
100 | .Attr( |
101 | "out_type: " |
102 | "{half,float,double,int32,uint16,uint8,int16,int8,int64,complex64," |
103 | "complex128,bool,bfloat16}" ) |
104 | .Attr("little_endian: bool = true" ) |
105 | .SetShapeFn([](InferenceContext* c) { |
106 | // Note: last dimension is data dependent. |
107 | ShapeHandle out; |
108 | TF_RETURN_IF_ERROR(c->Concatenate( |
109 | c->input(0), c->Vector(InferenceContext::kUnknownDim), &out)); |
110 | c->set_output(0, out); |
111 | return OkStatus(); |
112 | }); |
113 | |
114 | REGISTER_OP("DecodePaddedRaw" ) |
115 | .Input("input_bytes: string" ) |
116 | .Input("fixed_length: int32" ) |
117 | .Output("output: out_type" ) |
118 | .Attr( |
119 | "out_type: {half,float,double,int32,uint16,uint8,int16,int8,int64," |
120 | "bfloat16}" ) |
121 | .Attr("little_endian: bool = true" ) |
122 | .SetShapeFn([](InferenceContext* c) { |
123 | DimensionHandle fixed_length; |
124 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &fixed_length)); |
125 | |
126 | DataType out_type; |
127 | TF_RETURN_IF_ERROR(c->GetAttr("out_type" , &out_type)); |
128 | |
129 | int32_t data_type_size = DataTypeSize(out_type); |
130 | |
131 | DimensionHandle width; |
132 | TF_RETURN_IF_ERROR(c->Divide(fixed_length, data_type_size, true, &width)); |
133 | |
134 | ShapeHandle out; |
135 | TF_RETURN_IF_ERROR(c->Concatenate(c->input(0), c->Vector(width), &out)); |
136 | |
137 | c->set_output(0, out); |
138 | return OkStatus(); |
139 | }); |
140 | |
141 | REGISTER_OP("DecodeCompressed" ) |
142 | .Input("bytes: string" ) |
143 | .Output("output: string" ) |
144 | .Attr("compression_type: string = ''" ) |
145 | .SetShapeFn(shape_inference::UnchangedShape); |
146 | |
147 | REGISTER_OP("ParseExample" ) |
148 | .Input("serialized: string" ) |
149 | .Input("names: string" ) |
150 | .Input("sparse_keys: Nsparse * string" ) |
151 | .Input("dense_keys: Ndense * string" ) |
152 | .Input("dense_defaults: Tdense" ) |
153 | .Output("sparse_indices: Nsparse * int64" ) |
154 | .Output("sparse_values: sparse_types" ) |
155 | .Output("sparse_shapes: Nsparse * int64" ) |
156 | .Output("dense_values: Tdense" ) |
157 | .Attr("Nsparse: int >= 0" ) // Inferred from sparse_keys |
158 | .Attr("Ndense: int >= 0" ) // Inferred from dense_keys |
159 | .Attr("sparse_types: list({float,int64,string}) >= 0" ) |
160 | .Attr("Tdense: list({float,int64,string}) >= 0" ) |
161 | .Attr("dense_shapes: list(shape) >= 0" ) |
162 | .SetShapeFn([](InferenceContext* c) { |
163 | ParseExampleAttrs attrs; |
164 | TF_RETURN_IF_ERROR(attrs.Init(c, /*op_version=*/1)); |
165 | |
166 | ShapeHandle input; |
167 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input)); |
168 | ShapeHandle names; |
169 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &names)); |
170 | |
171 | int output_idx = 0; |
172 | AddSparseOutputShapes(attrs.num_sparse, input, 1, c, &output_idx); |
173 | TF_RETURN_IF_ERROR( |
174 | AddDenseOutputShapes(attrs.dense_shapes, input, c, &output_idx)); |
175 | return OkStatus(); |
176 | }); |
177 | |
178 | // Differences between ParseExample and ParseExampleV2: |
179 | // * Supports ragged features. |
180 | // * `serialized` may be a vector or a scalar. (With v1, `serialized` could |
181 | // only be a vector). |
182 | // * Each set of keys is passed with a vector instead of a list of scalars. |
183 | // * No Ndense attribute (not needed). |
184 | // * num_sparse (formerly Nsparse) is no longer inferred; you must specify it |
185 | // explicitly. |
186 | REGISTER_OP("ParseExampleV2" ) |
187 | .Input("serialized: string" ) |
188 | .Input("names: string" ) |
189 | .Input("sparse_keys: string" ) |
190 | .Input("dense_keys: string" ) |
191 | .Input("ragged_keys: string" ) |
192 | .Input("dense_defaults: Tdense" ) |
193 | .Output("sparse_indices: num_sparse * int64" ) |
194 | .Output("sparse_values: sparse_types" ) |
195 | .Output("sparse_shapes: num_sparse * int64" ) |
196 | .Output("dense_values: Tdense" ) |
197 | .Output("ragged_values: ragged_value_types" ) |
198 | .Output("ragged_row_splits: ragged_split_types" ) |
199 | .Attr("Tdense: list({float,int64,string}) >= 0" ) // Inferred |
200 | .Attr("num_sparse: int >= 0" ) |
201 | .Attr("sparse_types: list({float,int64,string}) >= 0" ) |
202 | .Attr("ragged_value_types: list({float,int64,string}) >= 0" ) |
203 | .Attr("ragged_split_types: list({int32,int64}) >= 0" ) |
204 | .Attr("dense_shapes: list(shape) >= 0" ) |
205 | |
206 | .SetShapeFn([](InferenceContext* c) { |
207 | ParseExampleAttrs attrs; |
208 | TF_RETURN_IF_ERROR(attrs.Init(c, /*op_version=*/2)); |
209 | |
210 | ShapeHandle input; |
211 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &input)); |
212 | ShapeHandle names; |
213 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &names)); |
214 | DimensionHandle num_examples = c->UnknownDim(); |
215 | if (c->RankKnown(input) && c->Rank(input) == 1) { |
216 | num_examples = c->Dim(input, 0); |
217 | } |
218 | |
219 | int output_idx = 0; |
220 | AddSparseOutputShapes(attrs.num_sparse, input, 1, c, &output_idx); |
221 | TF_RETURN_IF_ERROR( |
222 | AddDenseOutputShapes(attrs.dense_shapes, input, c, &output_idx)); |
223 | TF_RETURN_IF_ERROR(AddRaggedOutputShapes(attrs.num_ragged, false, |
224 | num_examples, c, &output_idx)); |
225 | |
226 | return OkStatus(); |
227 | }); |
228 | |
229 | REGISTER_OP("ParseSingleExample" ) |
230 | .Input("serialized: string" ) |
231 | .Input("dense_defaults: Tdense" ) |
232 | .Output("sparse_indices: num_sparse * int64" ) |
233 | .Output("sparse_values: sparse_types" ) |
234 | .Output("sparse_shapes: num_sparse * int64" ) |
235 | .Output("dense_values: Tdense" ) |
236 | .Attr("num_sparse: int >= 0" ) |
237 | .Attr("sparse_keys: list(string) >= 0" ) |
238 | .Attr("dense_keys: list(string) >= 0" ) |
239 | .Attr("sparse_types: list({float,int64,string}) >= 0" ) |
240 | .Attr("Tdense: list({float,int64,string}) >= 0" ) |
241 | .Attr("dense_shapes: list(shape) >= 0" ) |
242 | .SetShapeFn([](InferenceContext* c) { |
243 | ParseSingleExampleAttrs attrs; |
244 | TF_RETURN_IF_ERROR(attrs.Init(c)); |
245 | |
246 | ShapeHandle input; |
247 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input)); |
248 | |
249 | int output_idx = 0; |
250 | AddSparseOutputShapes(attrs.sparse_keys.size(), input, 1, c, &output_idx); |
251 | TF_RETURN_IF_ERROR( |
252 | AddDenseOutputShapes(attrs.dense_shapes, input, c, &output_idx)); |
253 | return OkStatus(); |
254 | }); |
255 | |
256 | REGISTER_OP("ParseSequenceExample" ) |
257 | .Input("serialized: string" ) |
258 | .Input("debug_name: string" ) |
259 | .Input("context_dense_defaults: Tcontext_dense" ) |
260 | .Output("context_sparse_indices: Ncontext_sparse * int64" ) |
261 | .Output("context_sparse_values: context_sparse_types" ) |
262 | .Output("context_sparse_shapes: Ncontext_sparse * int64" ) |
263 | .Output("context_dense_values: Tcontext_dense" ) |
264 | .Output("feature_list_sparse_indices: Nfeature_list_sparse * int64" ) |
265 | .Output("feature_list_sparse_values: feature_list_sparse_types" ) |
266 | .Output("feature_list_sparse_shapes: Nfeature_list_sparse * int64" ) |
267 | .Output("feature_list_dense_values: feature_list_dense_types" ) |
268 | .Output("feature_list_dense_lengths: Nfeature_list_dense * int64" ) |
269 | .Attr("feature_list_dense_missing_assumed_empty: list(string) >= 0" ) |
270 | .Attr("context_sparse_keys: list(string) >= 0" ) |
271 | .Attr("context_dense_keys: list(string) >= 0" ) |
272 | .Attr("feature_list_sparse_keys: list(string) >= 0" ) |
273 | .Attr("feature_list_dense_keys: list(string) >= 0" ) |
274 | .Attr("Ncontext_sparse: int >= 0 = 0" ) |
275 | .Attr("Ncontext_dense: int >= 0 = 0" ) |
276 | .Attr("Nfeature_list_sparse: int >= 0 = 0" ) |
277 | .Attr("Nfeature_list_dense: int >= 0 = 0" ) |
278 | .Attr("context_sparse_types: list({float,int64,string}) >= 0 = []" ) |
279 | .Attr("Tcontext_dense: list({float,int64,string}) >= 0 = []" ) |
280 | .Attr("feature_list_dense_types: list({float,int64,string}) >= 0 = []" ) |
281 | .Attr("context_dense_shapes: list(shape) >= 0 = []" ) |
282 | .Attr("feature_list_sparse_types: list({float,int64,string}) >= 0 = []" ) |
283 | .Attr("feature_list_dense_shapes: list(shape) >= 0 = []" ) |
284 | .SetShapeFn([](InferenceContext* c) { |
285 | ParseSequenceExampleAttrs attrs; |
286 | TF_RETURN_IF_ERROR(attrs.Init(c)); |
287 | |
288 | // Verify that the input is a vector, and carry the shape if known. |
289 | ShapeHandle input; |
290 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input)); |
291 | ShapeHandle names; |
292 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &names)); |
293 | DimensionHandle num_examples = c->Dim(input, 0); |
294 | ShapeHandle feature_list_dense_prefix = |
295 | c->Matrix(num_examples, c->UnknownDim()); |
296 | |
297 | int output_idx = 0; |
298 | AddSparseOutputShapes(attrs.num_context_sparse, input, 1, c, &output_idx); |
299 | TF_RETURN_IF_ERROR(AddDenseOutputShapes(attrs.context_dense_shapes, input, |
300 | c, &output_idx)); |
301 | AddSparseOutputShapes(attrs.num_feature_list_sparse, input, 2, c, |
302 | &output_idx); |
303 | TF_RETURN_IF_ERROR(AddDenseOutputShapes(attrs.feature_list_dense_shapes, |
304 | feature_list_dense_prefix, c, |
305 | &output_idx)); |
306 | AddDenseLengthsShapes(attrs.num_feature_list_dense, input, c, |
307 | &output_idx); |
308 | |
309 | return OkStatus(); |
310 | }); |
311 | |
312 | // Differences between ParseSequenceExample and ParseSequenceExampleV2: |
313 | // * Supports ragged features. |
314 | // * `serialized` may be a vector or a scalar. (With v1, `serialized` could |
315 | // only be a vector). |
316 | // * Each set of keys is passed with a vector instead of an attr list. |
// * feature_list_dense_missing_assumed_empty is passed as a boolean
//   vector (aligned 1:1 with feature_list_dense_keys) rather than as an
//   attribute containing a list of strings.
320 | // * No Ncontext_dense attribute (not needed). |
321 | REGISTER_OP("ParseSequenceExampleV2" ) |
322 | .Input("serialized: string" ) |
323 | .Input("debug_name: string" ) |
324 | // Inputs: context features |
325 | .Input("context_sparse_keys: string" ) |
326 | .Input("context_dense_keys: string" ) |
327 | .Input("context_ragged_keys: string" ) |
328 | // Inputs: feature lists |
329 | .Input("feature_list_sparse_keys: string" ) |
330 | .Input("feature_list_dense_keys: string" ) |
331 | .Input("feature_list_ragged_keys: string" ) |
332 | .Input("feature_list_dense_missing_assumed_empty: bool" ) |
333 | .Input("context_dense_defaults: Tcontext_dense" ) |
334 | // Outputs: context features |
335 | .Output("context_sparse_indices: Ncontext_sparse * int64" ) |
336 | .Output("context_sparse_values: context_sparse_types" ) |
337 | .Output("context_sparse_shapes: Ncontext_sparse * int64" ) |
338 | .Output("context_dense_values: Tcontext_dense" ) |
339 | .Output("context_ragged_values: context_ragged_value_types" ) |
340 | .Output("context_ragged_row_splits: context_ragged_split_types" ) |
341 | // Outputs: feature lists |
342 | .Output("feature_list_sparse_indices: Nfeature_list_sparse * int64" ) |
343 | .Output("feature_list_sparse_values: feature_list_sparse_types" ) |
344 | .Output("feature_list_sparse_shapes: Nfeature_list_sparse * int64" ) |
345 | .Output("feature_list_dense_values: feature_list_dense_types" ) |
346 | .Output("feature_list_dense_lengths: Nfeature_list_dense * int64" ) |
347 | .Output("feature_list_ragged_values: feature_list_ragged_value_types" ) |
348 | .Output("feature_list_ragged_outer_splits: feature_list_ragged_split_types" ) |
349 | .Output("feature_list_ragged_inner_splits: feature_list_ragged_split_types" ) |
350 | // Attribs: context features |
351 | .Attr("Ncontext_sparse: int >= 0 = 0" ) |
352 | .Attr("Tcontext_dense: list({float,int64,string}) >= 0 = []" ) // inferred |
353 | .Attr("context_sparse_types: list({float,int64,string}) >= 0 = []" ) |
354 | .Attr("context_ragged_value_types: list({float,int64,string}) >= 0 = []" ) |
355 | .Attr("context_ragged_split_types: list({int32,int64}) >= 0 = []" ) |
356 | .Attr("context_dense_shapes: list(shape) >= 0 = []" ) |
357 | // Attribs: feature lists |
358 | .Attr("Nfeature_list_sparse: int >= 0 = 0" ) |
359 | .Attr("Nfeature_list_dense: int >= 0 = 0" ) |
360 | .Attr("feature_list_dense_types: list({float,int64,string}) >= 0 = []" ) |
361 | .Attr("feature_list_sparse_types: list({float,int64,string}) >= 0 = []" ) |
362 | .Attr( |
363 | "feature_list_ragged_value_types: list({float,int64,string}) >= 0 = []" ) |
364 | .Attr("feature_list_ragged_split_types: list({int32,int64}) >= 0 = []" ) |
365 | .Attr("feature_list_dense_shapes: list(shape) >= 0 = []" ) |
366 | .SetShapeFn([](InferenceContext* c) { |
367 | ParseSequenceExampleAttrs attrs; |
368 | TF_RETURN_IF_ERROR(attrs.Init(c, /*op_version=*/2)); |
369 | ShapeHandle input; |
370 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &input)); |
371 | ShapeHandle names; |
372 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &names)); |
373 | ShapeHandle feature_list_dense_prefix; |
374 | TF_RETURN_IF_ERROR(c->Concatenate(input, c->UnknownShapeOfRank(1), |
375 | &feature_list_dense_prefix)); |
376 | DimensionHandle num_examples = c->UnknownDim(); |
377 | if (c->RankKnown(input) && c->Rank(input) == 1) { |
378 | num_examples = c->Dim(input, 0); |
379 | } |
380 | |
381 | int output_idx = 0; |
382 | // Context outputs. |
383 | AddSparseOutputShapes(attrs.num_context_sparse, input, 1, c, &output_idx); |
384 | TF_RETURN_IF_ERROR(AddDenseOutputShapes(attrs.context_dense_shapes, input, |
385 | c, &output_idx)); |
386 | TF_RETURN_IF_ERROR(AddRaggedOutputShapes(attrs.num_context_ragged, false, |
387 | num_examples, c, &output_idx)); |
388 | // FeatureList outputs. |
389 | AddSparseOutputShapes(attrs.num_feature_list_sparse, input, 2, c, |
390 | &output_idx); |
391 | TF_RETURN_IF_ERROR(AddDenseOutputShapes(attrs.feature_list_dense_shapes, |
392 | feature_list_dense_prefix, c, |
393 | &output_idx)); |
394 | AddDenseLengthsShapes(attrs.num_feature_list_dense, input, c, |
395 | &output_idx); |
396 | TF_RETURN_IF_ERROR(AddRaggedOutputShapes( |
397 | attrs.num_feature_list_ragged, true, num_examples, c, &output_idx)); |
398 | return OkStatus(); |
399 | }); |
400 | |
401 | REGISTER_OP("ParseSingleSequenceExample" ) |
402 | .Input("serialized: string" ) |
403 | .Input("feature_list_dense_missing_assumed_empty: string" ) |
404 | .Input("context_sparse_keys: Ncontext_sparse * string" ) |
405 | .Input("context_dense_keys: Ncontext_dense * string" ) |
406 | .Input("feature_list_sparse_keys: Nfeature_list_sparse * string" ) |
407 | .Input("feature_list_dense_keys: Nfeature_list_dense * string" ) |
408 | .Input("context_dense_defaults: Tcontext_dense" ) |
409 | .Input("debug_name: string" ) |
410 | .Output("context_sparse_indices: Ncontext_sparse * int64" ) |
411 | .Output("context_sparse_values: context_sparse_types" ) |
412 | .Output("context_sparse_shapes: Ncontext_sparse * int64" ) |
413 | .Output("context_dense_values: Tcontext_dense" ) |
414 | .Output("feature_list_sparse_indices: Nfeature_list_sparse * int64" ) |
415 | .Output("feature_list_sparse_values: feature_list_sparse_types" ) |
416 | .Output("feature_list_sparse_shapes: Nfeature_list_sparse * int64" ) |
417 | .Output("feature_list_dense_values: feature_list_dense_types" ) |
418 | // Infer from context_sparse_keys |
419 | .Attr("Ncontext_sparse: int >= 0 = 0" ) |
420 | // Infer from context_dense_keys |
421 | .Attr("Ncontext_dense: int >= 0 = 0" ) |
422 | // Infer from feature_list_sparse_keys |
423 | .Attr("Nfeature_list_sparse: int >= 0 = 0" ) |
424 | // Infer from feature_list_dense_keys |
425 | .Attr("Nfeature_list_dense: int >= 0 = 0" ) |
426 | .Attr("context_sparse_types: list({float,int64,string}) >= 0 = []" ) |
427 | .Attr("Tcontext_dense: list({float,int64,string}) >= 0 = []" ) |
428 | .Attr("feature_list_dense_types: list({float,int64,string}) >= 0 = []" ) |
429 | .Attr("context_dense_shapes: list(shape) >= 0 = []" ) |
430 | .Attr("feature_list_sparse_types: list({float,int64,string}) >= 0 = []" ) |
431 | .Attr("feature_list_dense_shapes: list(shape) >= 0 = []" ) |
432 | .SetShapeFn([](InferenceContext* c) { |
433 | ShapeHandle unused; |
434 | ParseSingleSequenceExampleAttrs attrs; |
435 | TF_RETURN_IF_ERROR(attrs.Init(c)); |
436 | |
437 | ShapeHandle input; |
438 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input)); |
439 | |
440 | // feature_list_dense_missing_assumed_empty |
441 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
442 | |
443 | int output_idx = 0; |
444 | AddSparseOutputShapes(attrs.num_context_sparse, input, 1, c, &output_idx); |
445 | TF_RETURN_IF_ERROR(AddDenseOutputShapes(attrs.context_dense_shapes, input, |
446 | c, &output_idx)); |
447 | AddSparseOutputShapes(attrs.num_feature_list_sparse, input, 2, c, |
448 | &output_idx); |
449 | TF_RETURN_IF_ERROR(AddDenseOutputShapes(attrs.feature_list_dense_shapes, |
450 | c->UnknownShapeOfRank(1), c, |
451 | &output_idx)); |
452 | return OkStatus(); |
453 | }); |
454 | |
455 | REGISTER_OP("ParseTensor" ) |
456 | .Input("serialized: string" ) |
457 | .Output("output: out_type" ) |
458 | .Attr("out_type: type" ) |
459 | .SetShapeFn(shape_inference::UnknownShape); |
460 | |
461 | REGISTER_OP("SerializeTensor" ) |
462 | .Input("tensor: T" ) |
463 | .Output("serialized: string" ) |
464 | .Attr("T: type" ) |
465 | .SetShapeFn(shape_inference::ScalarShape); |
466 | |
467 | REGISTER_OP("DecodeJSONExample" ) |
468 | .Input("json_examples: string" ) |
469 | .Output("binary_examples: string" ) |
470 | .SetShapeFn(shape_inference::UnchangedShape); |
471 | |
472 | REGISTER_OP("DecodeCSV" ) |
473 | .Input("records: string" ) |
474 | .Input("record_defaults: OUT_TYPE" ) |
475 | .Output("output: OUT_TYPE" ) |
476 | .Attr("OUT_TYPE: list({float,double,int32,int64,string})" ) |
477 | .Attr("field_delim: string = ','" ) |
478 | .Attr("use_quote_delim: bool = true" ) |
479 | .Attr("na_value: string = ''" ) |
480 | .Attr("select_cols: list(int) = []" ) |
481 | .SetShapeFn([](InferenceContext* c) { |
482 | // Validate the record_defaults inputs. |
483 | for (int i = 1; i < c->num_inputs(); ++i) { |
484 | ShapeHandle v; |
485 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); |
486 | if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { |
487 | return errors::InvalidArgument( |
488 | "Shape of a default must be a length-0 or length-1 vector, or a " |
489 | "scalar." ); |
490 | } |
491 | } |
492 | |
493 | // Propagate shape of the records input. |
494 | for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, c->input(0)); |
495 | return OkStatus(); |
496 | }); |
497 | |
498 | REGISTER_OP("StringToNumber" ) |
499 | .Input("string_tensor: string" ) |
500 | .Output("output: out_type" ) |
501 | .Attr("out_type: {float, double, int32, int64} = DT_FLOAT" ) |
502 | .SetShapeFn(shape_inference::UnchangedShape); |
503 | |
504 | } // namespace tensorflow |
505 | |