1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/framework/common_shape_fns.h"
16#include "tensorflow/core/framework/full_type.pb.h"
17#include "tensorflow/core/framework/op.h"
18
19namespace tensorflow {
20
21REGISTER_OP("AssertCardinalityDataset")
22 .Input("input_dataset: variant")
23 .Input("cardinality: int64")
24 .Output("handle: variant")
25 .Attr("output_types: list(type) >= 1")
26 .Attr("output_shapes: list(shape) >= 1")
27 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
28 "output_types"))
29 .SetShapeFn([](shape_inference::InferenceContext* c) {
30 shape_inference::ShapeHandle unused;
31 // cardinality should be a scalar.
32 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
33 return shape_inference::ScalarShape(c);
34 });
35
36REGISTER_OP("AssertNextDataset")
37 .Input("input_dataset: variant")
38 .Input("transformations: string")
39 .Output("handle: variant")
40 .Attr("output_types: list(type) >= 1")
41 .Attr("output_shapes: list(shape) >= 1")
42 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
43 "output_types"))
44 .SetShapeFn([](shape_inference::InferenceContext* c) {
45 shape_inference::ShapeHandle unused;
46 // transformations should be a vector.
47 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
48 return shape_inference::ScalarShape(c);
49 });
50
51REGISTER_OP("ExperimentalAssertNextDataset")
52 .Input("input_dataset: variant")
53 .Input("transformations: string")
54 .Output("handle: variant")
55 .Attr("output_types: list(type) >= 1")
56 .Attr("output_shapes: list(shape) >= 1")
57 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
58 "output_types"))
59 .SetShapeFn([](shape_inference::InferenceContext* c) {
60 shape_inference::ShapeHandle unused;
61 // transformations should be a vector.
62 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
63 return shape_inference::ScalarShape(c);
64 });
65
66REGISTER_OP("AssertPrevDataset")
67 .Input("input_dataset: variant")
68 .Input("transformations: string")
69 .Output("handle: variant")
70 .Attr("output_types: list(type) >= 1")
71 .Attr("output_shapes: list(shape) >= 1")
72 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
73 "output_types"))
74 .SetShapeFn([](shape_inference::InferenceContext* c) {
75 shape_inference::ShapeHandle unused;
76 // transformations should be a vector.
77 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
78 return shape_inference::ScalarShape(c);
79 });
80
// Registers AutoShardDataset: takes `input_dataset` plus int64 `num_workers`
// and `index` inputs and produces a variant dataset handle.
// The forward type function applies `full_type::ShardTensor` to the element
// types of input 0 via `ContainerMap`. The shape function only sets a scalar
// output; it does not validate the ranks of `num_workers` or `index`.
REGISTER_OP("AutoShardDataset")
    .Input("input_dataset: variant")
    .Input("num_workers: int64")
    .Input("index: int64")
    .Output("handle: variant")
    .Attr("auto_shard_policy: int = 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("num_replicas: int = 0")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
                                              full_type::ShardTensor))
    .SetShapeFn(shape_inference::ScalarShape);
95
// Registers ExperimentalAutoShardDataset, the experimental-namespace variant of
// AutoShardDataset. Unlike that op it has no `num_replicas` attr and no
// forward type function. The shape function only sets a scalar output.
REGISTER_OP("ExperimentalAutoShardDataset")
    .Input("input_dataset: variant")
    .Input("num_workers: int64")
    .Input("index: int64")
    .Output("handle: variant")
    .Attr("auto_shard_policy: int = 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
107
108REGISTER_OP("BytesProducedStatsDataset")
109 .Input("input_dataset: variant")
110 .Input("tag: string")
111 .Output("handle: variant")
112 .Attr("output_types: list(type) >= 1")
113 .Attr("output_shapes: list(shape) >= 1")
114 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
115 "output_types"))
116 .SetShapeFn([](shape_inference::InferenceContext* c) {
117 shape_inference::ShapeHandle tag_shape;
118 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
119 return shape_inference::ScalarShape(c);
120 });
121
122REGISTER_OP("ExperimentalBytesProducedStatsDataset")
123 .Input("input_dataset: variant")
124 .Input("tag: string")
125 .Output("handle: variant")
126 .Attr("output_types: list(type) >= 1")
127 .Attr("output_shapes: list(shape) >= 1")
128 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
129 "output_types"))
130 .SetShapeFn([](shape_inference::InferenceContext* c) {
131 shape_inference::ShapeHandle tag_shape;
132 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
133 return shape_inference::ScalarShape(c);
134 });
135
// Registers ChooseFastestBranchDataset.
// Inputs: `input_dataset`, int64 `ratio_numerator` / `ratio_denominator`, and
// the variadic `other_arguments` (typed by `Targuments`). `branches` is a list
// of functions; `other_arguments_lengths` presumably splits `other_arguments`
// across the branches (TODO confirm against the kernel). The shape function
// only sets a scalar output and does not check input ranks.
REGISTER_OP("ChooseFastestBranchDataset")
    .Input("input_dataset: variant")
    .Input("ratio_numerator: int64")
    .Input("ratio_denominator: int64")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("Targuments: list(type) >= 0")
    .Attr("num_elements_per_branch: int >= 1")
    .Attr("branches: list(func) >= 1")
    .Attr("other_arguments_lengths: list(int) >= 1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
151
// Registers ChooseFastestDataset: takes N >= 2 candidate dataset variants and
// an int `num_experiments` attr, and produces a single scalar variant handle.
REGISTER_OP("ChooseFastestDataset")
    .Input("input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("N: int >= 2")
    .Attr("num_experiments: int")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
162
// Registers ExperimentalChooseFastestDataset, the experimental-namespace twin
// of ChooseFastestDataset with the same signature.
REGISTER_OP("ExperimentalChooseFastestDataset")
    .Input("input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("N: int >= 2")
    .Attr("num_experiments: int")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
173
// Registers CompressElement: packs the `components` tensors (typed by
// `input_types`) into a single scalar variant `compressed`.
REGISTER_OP("CompressElement")
    .Input("components: input_types")
    .Output("compressed: variant")
    .Attr("input_types: list(type) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
179
// Registers UncompressElement: the inverse of CompressElement's signature,
// mapping a `compressed` variant back to `components` tensors typed by
// `output_types`. Output shapes come from `DatasetIteratorShape`, which reads
// the `output_shapes` attr.
REGISTER_OP("UncompressElement")
    .Input("compressed: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
186
// Registers ComputeBatchSize: takes a dataset variant and returns a scalar
// int64 `batch_size`.
REGISTER_OP("ComputeBatchSize")
    .Input("input_dataset : variant")
    .Output("batch_size : int64")
    .SetShapeFn(shape_inference::ScalarShape);
191
192REGISTER_OP("CSVDataset")
193 .Input("filenames: string")
194 .Input("compression_type: string")
195 .Input("buffer_size: int64")
196 .Input("header: bool")
197 .Input("field_delim: string")
198 .Input("use_quote_delim: bool")
199 .Input("na_value: string")
200 .Input("select_cols: int64")
201 .Input("record_defaults: output_types")
202 .Output("handle: variant")
203 .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
204 .Attr("output_shapes: list(shape) >= 1")
205 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
206 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
207 "output_types"))
208 .SetShapeFn([](shape_inference::InferenceContext* c) {
209 shape_inference::ShapeHandle unused;
210 // `filenames` must be a scalar or a vector.
211 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
212 // `compression_type`, `buffer_size`, `header`, `field_delim`,
213 // `use_quote_delim`, `na_value` must be scalars
214 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
215 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
216 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
217 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
218 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
219 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
220 // `select_cols` must be a vector
221 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
222 // `record_defaults` must be lists of scalars
223 for (size_t i = 8; i < c->num_inputs(); ++i) {
224 shape_inference::ShapeHandle v;
225 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
226 if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
227 return errors::InvalidArgument(
228 "Shape of a default must be a length-0 or length-1 vector, or a "
229 "scalar.");
230 }
231 }
232 return shape_inference::ScalarShape(c);
233 });
234
235REGISTER_OP("CSVDatasetV2")
236 .Input("filenames: string")
237 .Input("compression_type: string")
238 .Input("buffer_size: int64")
239 .Input("header: bool")
240 .Input("field_delim: string")
241 .Input("use_quote_delim: bool")
242 .Input("na_value: string")
243 .Input("select_cols: int64")
244 .Input("record_defaults: output_types")
245 .Input("exclude_cols: int64")
246 .Output("handle: variant")
247 .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
248 .Attr("output_shapes: list(shape) >= 1")
249 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
250 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
251 "output_types"))
252 .SetShapeFn([](shape_inference::InferenceContext* c) {
253 shape_inference::ShapeHandle unused;
254 // `filenames` must be a scalar or a vector.
255 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
256 // `compression_type`, `buffer_size`, `header`, `field_delim`,
257 // `use_quote_delim`, `na_value` must be scalars
258 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
259 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
260 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
261 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
262 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
263 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
264 // `select_cols` must be a vector
265 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
266 // `exclude_cols` must be a vector
267 TF_RETURN_IF_ERROR(
268 c->WithRank(c->input(c->num_inputs() - 1), 1, &unused));
269 // `record_defaults` must be lists of scalars
270 for (size_t i = 8; i < c->num_inputs() - 1; ++i) {
271 shape_inference::ShapeHandle v;
272 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
273 if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
274 return errors::InvalidArgument(
275 "Shape of a default must be a length-0 or length-1 vector, or a "
276 "scalar.");
277 }
278 }
279 return shape_inference::ScalarShape(c);
280 });
281
282REGISTER_OP("ExperimentalCSVDataset")
283 .Input("filenames: string")
284 .Input("compression_type: string")
285 .Input("buffer_size: int64")
286 .Input("header: bool")
287 .Input("field_delim: string")
288 .Input("use_quote_delim: bool")
289 .Input("na_value: string")
290 .Input("select_cols: int64")
291 .Input("record_defaults: output_types")
292 .Output("handle: variant")
293 .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
294 .Attr("output_shapes: list(shape) >= 1")
295 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
296 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
297 "output_types"))
298 .SetShapeFn([](shape_inference::InferenceContext* c) {
299 shape_inference::ShapeHandle unused;
300 // `filenames` must be a scalar or a vector.
301 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
302 // `compression_type`, `buffer_size`, `header`, `field_delim`,
303 // `use_quote_delim`, `na_value` must be scalars
304 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
305 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
306 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
307 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
308 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
309 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
310 // `select_cols` must be a vector
311 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
312 // `record_defaults` must be lists of scalars
313 for (size_t i = 8; i < c->num_inputs(); ++i) {
314 shape_inference::ShapeHandle v;
315 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
316 if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
317 return errors::InvalidArgument(
318 "Shape of a default must be a length-0 or length-1 vector, or a "
319 "scalar.");
320 }
321 }
322 return shape_inference::ScalarShape(c);
323 });
324
// Registers ExperimentalDatasetCardinality: takes a dataset variant and
// returns a scalar int64 `cardinality`.
REGISTER_OP("ExperimentalDatasetCardinality")
    .Input("input_dataset: variant")
    .Output("cardinality: int64")
    .SetShapeFn(shape_inference::ScalarShape);
329
// Registers DatasetFromGraph: builds a scalar dataset handle from a string
// `graph_def` input. Its output full type is derived by the forward type
// function `Decode(TFT_STRING, 0)` applied to input 0.
REGISTER_OP("DatasetFromGraph")
    .Input("graph_def: string")
    .Output("handle: variant")
    .SetTypeConstructor(full_type::UnaryGeneric(TFT_DATASET))
    .SetForwardTypeFn(full_type::Decode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);
336
// Registers DatasetToTFRecord: consumes `input_dataset` with string `filename`
// and `compression_type` inputs and has no outputs.
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
REGISTER_OP("DatasetToTFRecord")
    .Input("input_dataset: variant")
    .Input("filename: string")
    .Input("compression_type: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);
347
// Registers ExperimentalDatasetToTFRecord, the experimental-namespace twin of
// DatasetToTFRecord. It is also marked stateful and has no outputs.
REGISTER_OP("ExperimentalDatasetToTFRecord")
    .Input("input_dataset: variant")
    .Input("filename: string")
    .Input("compression_type: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);
354
355REGISTER_OP("DenseToSparseBatchDataset")
356 .Input("input_dataset: variant")
357 .Input("batch_size: int64")
358 .Input("row_shape: int64")
359 .Output("handle: variant")
360 .Attr("output_types: list(type) >= 1")
361 .Attr("output_shapes: list(shape) >= 1")
362 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
363 "output_types"))
364 .SetShapeFn([](shape_inference::InferenceContext* c) {
365 shape_inference::ShapeHandle unused;
366 // batch_size should be a scalar.
367 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
368 // row_shape should be a 1-D vector.
369 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
370 return shape_inference::ScalarShape(c);
371 });
372
373REGISTER_OP("ExperimentalDenseToSparseBatchDataset")
374 .Input("input_dataset: variant")
375 .Input("batch_size: int64")
376 .Input("row_shape: int64")
377 .Output("handle: variant")
378 .Attr("output_types: list(type) >= 1")
379 .Attr("output_shapes: list(shape) >= 1")
380 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
381 "output_types"))
382 .SetShapeFn([](shape_inference::InferenceContext* c) {
383 shape_inference::ShapeHandle unused;
384 // batch_size should be a scalar.
385 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
386 // row_shape should be a 1-D vector.
387 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
388 return shape_inference::ScalarShape(c);
389 });
390
// Registers DirectedInterleaveDataset: takes a `selector_input_dataset` and
// N >= 1 `data_input_datasets`. The bool attr `stop_on_empty_dataset` defaults
// to false. The shape function only sets a scalar output.
REGISTER_OP("DirectedInterleaveDataset")
    .Input("selector_input_dataset: variant")
    .Input("data_input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .Attr("stop_on_empty_dataset: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
402
// Registers ExperimentalDirectedInterleaveDataset, the experimental-namespace
// variant of DirectedInterleaveDataset. It has no `stop_on_empty_dataset`
// attr.
REGISTER_OP("ExperimentalDirectedInterleaveDataset")
    .Input("selector_input_dataset: variant")
    .Input("data_input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
413
// Registers GroupByReducerDataset: parameterized by four function attrs
// (`key_func`, `init_func`, `reduce_func`, `finalize_func`). Each function has
// its own variadic captured-argument input typed by the matching `T*` attr.
// The op is marked stateful, and the shape function only sets a scalar output.
REGISTER_OP("GroupByReducerDataset")
    .Input("input_dataset: variant")
    .Input("key_func_other_arguments: Tkey_func_other_arguments")
    .Input("init_func_other_arguments: Tinit_func_other_arguments")
    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
    .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
    .Output("handle: variant")
    .Attr("key_func: func")
    .Attr("init_func: func")
    .Attr("reduce_func: func")
    .Attr("finalize_func: func")
    .Attr("Tkey_func_other_arguments: list(type) >= 0")
    .Attr("Tinit_func_other_arguments: list(type) >= 0")
    .Attr("Treduce_func_other_arguments: list(type) >= 0")
    .Attr("Tfinalize_func_other_arguments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
435
// Registers ExperimentalGroupByReducerDataset, the experimental-namespace twin
// of GroupByReducerDataset with the same signature. It is also marked
// stateful.
REGISTER_OP("ExperimentalGroupByReducerDataset")
    .Input("input_dataset: variant")
    .Input("key_func_other_arguments: Tkey_func_other_arguments")
    .Input("init_func_other_arguments: Tinit_func_other_arguments")
    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
    .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
    .Output("handle: variant")
    .Attr("key_func: func")
    .Attr("init_func: func")
    .Attr("reduce_func: func")
    .Attr("finalize_func: func")
    .Attr("Tkey_func_other_arguments: list(type) >= 0")
    .Attr("Tinit_func_other_arguments: list(type) >= 0")
    .Attr("Treduce_func_other_arguments: list(type) >= 0")
    .Attr("Tfinalize_func_other_arguments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
457
// Registers GroupByWindowDataset: parameterized by `key_func`, `reduce_func`
// and `window_size_func`. Each function has its own variadic captured-argument
// input typed by the matching `T*` attr. `metadata` is a string attr that
// defaults to empty. The shape function only sets a scalar output.
REGISTER_OP("GroupByWindowDataset")
    .Input("input_dataset: variant")
    .Input("key_func_other_arguments: Tkey_func_other_arguments")
    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
    .Input(
        "window_size_func_other_arguments: Twindow_size_func_other_arguments")
    .Output("handle: variant")
    .Attr("key_func: func")
    .Attr("reduce_func: func")
    .Attr("window_size_func: func")
    .Attr("Tkey_func_other_arguments: list(type) >= 0")
    .Attr("Treduce_func_other_arguments: list(type) >= 0")
    .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
477
478REGISTER_OP("GetElementAtIndex")
479 .Input("dataset: variant")
480 .Input("index: int64")
481 .Output("components: output_types")
482 .Attr("output_types: list(type) >= 1")
483 .Attr("output_shapes: list(shape) >= 1")
484 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
485 "output_types"))
486 .SetShapeFn(shape_inference::DatasetIteratorShape);
487
// Registers ExperimentalGroupByWindowDataset, the experimental-namespace
// variant of GroupByWindowDataset. It has no `metadata` attr.
REGISTER_OP("ExperimentalGroupByWindowDataset")
    .Input("input_dataset: variant")
    .Input("key_func_other_arguments: Tkey_func_other_arguments")
    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
    .Input(
        "window_size_func_other_arguments: Twindow_size_func_other_arguments")
    .Output("handle: variant")
    .Attr("key_func: func")
    .Attr("reduce_func: func")
    .Attr("window_size_func: func")
    .Attr("Tkey_func_other_arguments: list(type) >= 0")
    .Attr("Treduce_func_other_arguments: list(type) >= 0")
    .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
506
// Registers IgnoreErrorsDataset: wraps `input_dataset`. The bool attr
// `log_warning` defaults to false.
REGISTER_OP("IgnoreErrorsDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("log_warning: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
516
// Registers ExperimentalIgnoreErrorsDataset, the experimental-namespace twin
// of IgnoreErrorsDataset with the same signature.
REGISTER_OP("ExperimentalIgnoreErrorsDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("log_warning: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
526
// Registers IteratorGetDevice: takes a resource handle and returns a scalar
// string `device`.
REGISTER_OP("IteratorGetDevice")
    .Input("resource: resource")
    .Output("device: string")
    .SetShapeFn(shape_inference::ScalarShape);
531
// Registers ExperimentalIteratorGetDevice, the experimental-namespace twin of
// IteratorGetDevice.
REGISTER_OP("ExperimentalIteratorGetDevice")
    .Input("resource: resource")
    .Output("device: string")
    .SetShapeFn(shape_inference::ScalarShape);
536
537REGISTER_OP("LatencyStatsDataset")
538 .Input("input_dataset: variant")
539 .Input("tag: string")
540 .Output("handle: variant")
541 .Attr("output_types: list(type) >= 1")
542 .Attr("output_shapes: list(shape) >= 1")
543 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
544 "output_types"))
545 .SetShapeFn([](shape_inference::InferenceContext* c) {
546 shape_inference::ShapeHandle tag_shape;
547 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
548 return shape_inference::ScalarShape(c);
549 });
550
551REGISTER_OP("ExperimentalLatencyStatsDataset")
552 .Input("input_dataset: variant")
553 .Input("tag: string")
554 .Output("handle: variant")
555 .Attr("output_types: list(type) >= 1")
556 .Attr("output_shapes: list(shape) >= 1")
557 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
558 "output_types"))
559 .SetShapeFn([](shape_inference::InferenceContext* c) {
560 shape_inference::ShapeHandle tag_shape;
561 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
562 return shape_inference::ScalarShape(c);
563 });
564
// Registers LMDBDataset: a source dataset over string `filenames`. The shape
// function only sets a scalar output; it does not check the rank of
// `filenames`.
REGISTER_OP("LMDBDataset")
    .Input("filenames: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
574
// Registers ExperimentalLMDBDataset, the experimental-namespace twin of
// LMDBDataset with the same signature.
REGISTER_OP("ExperimentalLMDBDataset")
    .Input("filenames: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
584
585REGISTER_OP("MapAndBatchDataset")
586 .Input("input_dataset: variant")
587 .Input("other_arguments: Targuments")
588 .Input("batch_size: int64")
589 .Input("num_parallel_calls: int64")
590 .Input("drop_remainder: bool")
591 .Output("handle: variant")
592 .Attr("f: func")
593 .Attr("Targuments: list(type) >= 0")
594 .Attr("output_types: list(type) >= 1")
595 .Attr("output_shapes: list(shape) >= 1")
596 .Attr("preserve_cardinality: bool = false")
597 .Attr("metadata: string = ''")
598 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
599 "output_types"))
600 .SetShapeFn([](shape_inference::InferenceContext* c) {
601 // Use index from the end to retrieve the Input shapes,
602 // so that to avoid guessing the length of "other_arguments".
603 // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars.
604 shape_inference::ShapeHandle unused;
605 TF_RETURN_IF_ERROR(
606 c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
607 TF_RETURN_IF_ERROR(
608 c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
609 TF_RETURN_IF_ERROR(
610 c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
611
612 return shape_inference::ScalarShape(c);
613 });
614
615REGISTER_OP("ExperimentalMapAndBatchDataset")
616 .Input("input_dataset: variant")
617 .Input("other_arguments: Targuments")
618 .Input("batch_size: int64")
619 .Input("num_parallel_calls: int64")
620 .Input("drop_remainder: bool")
621 .Output("handle: variant")
622 .Attr("f: func")
623 .Attr("Targuments: list(type) >= 0")
624 .Attr("output_types: list(type) >= 1")
625 .Attr("output_shapes: list(shape) >= 1")
626 .Attr("preserve_cardinality: bool = false")
627 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
628 "output_types"))
629 .SetShapeFn([](shape_inference::InferenceContext* c) {
630 // Use index from the end to retrieve the Input shapes,
631 // so that to avoid guessing the length of "other_arguments".
632 // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars.
633 shape_inference::ShapeHandle unused;
634 TF_RETURN_IF_ERROR(
635 c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
636 TF_RETURN_IF_ERROR(
637 c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
638 TF_RETURN_IF_ERROR(
639 c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
640
641 return shape_inference::ScalarShape(c);
642 });
643
// Registers ExperimentalMapDataset: applies function attr `f` with variadic
// captured `other_arguments` (typed by `Targuments`).
// `use_inter_op_parallelism` defaults to true and `preserve_cardinality`
// defaults to false. The shape function only sets a scalar output.
REGISTER_OP("ExperimentalMapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("preserve_cardinality: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
657
658REGISTER_OP("MatchingFilesDataset")
659 .Input("patterns: string")
660 .Output("handle: variant")
661 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
662 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
663 TFT_STRING))
664 .SetShapeFn([](shape_inference::InferenceContext* c) {
665 shape_inference::ShapeHandle unused;
666 // `patterns` must be a scalar or a vector.
667 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
668 return shape_inference::ScalarShape(c);
669 });
670
671REGISTER_OP("ExperimentalMatchingFilesDataset")
672 .Input("patterns: string")
673 .Output("handle: variant")
674 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
675 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
676 TFT_STRING))
677 .SetShapeFn([](shape_inference::InferenceContext* c) {
678 shape_inference::ShapeHandle unused;
679 // `patterns` must be a scalar or a vector.
680 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
681 return shape_inference::ScalarShape(c);
682 });
683
// Registers MaxIntraOpParallelismDataset: wraps `input_dataset` with an int64
// `max_intra_op_parallelism` input. The shape function only sets a scalar
// output and does not check that input's rank.
REGISTER_OP("MaxIntraOpParallelismDataset")
    .Input("input_dataset: variant")
    .Input("max_intra_op_parallelism: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
693
// Registers ExperimentalMaxIntraOpParallelismDataset, the experimental-
// namespace twin of MaxIntraOpParallelismDataset with the same signature.
REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset")
    .Input("input_dataset: variant")
    .Input("max_intra_op_parallelism: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
703
// Registers NonSerializableDataset: wraps `input_dataset` and produces a scalar
// variant handle.
REGISTER_OP("NonSerializableDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
712
// Registers ExperimentalNonSerializableDataset, the experimental-namespace twin
// of NonSerializableDataset.
REGISTER_OP("ExperimentalNonSerializableDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
721
// Registers ParallelInterleaveDataset: interleaves datasets produced by
// function `f` (with variadic captured `other_arguments`). Its tuning inputs
// are `cycle_length`, `block_length`, `sloppy`, `buffer_output_elements` and
// `prefetch_input_elements`. The shape function only sets a scalar output and
// does not check the ranks of those inputs.
REGISTER_OP("ParallelInterleaveDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("sloppy: bool")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
739
// This is the V2 of ParallelInterleaveDataset, renamed to differentiate it
// from the non-experimental ParallelInterleaveDataset op. Compared with
// ParallelInterleaveDataset above, the `sloppy` bool input is replaced by the
// string `deterministic` attr.
REGISTER_OP("LegacyParallelInterleaveDatasetV2")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Output("handle: variant")
    .Attr("f: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
760
// This op is no longer used. We keep it so that graphs written by old versions
// of TensorFlow can still be read. Its signature matches
// ParallelInterleaveDataset except that it has no `metadata` attr.
REGISTER_OP("ExperimentalParallelInterleaveDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("sloppy: bool")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
779
// Registers ParseExampleDataset: parses elements of `input_dataset` into
// sparse, dense and ragged features. The feature spec is given by the key/type
// attrs, and `dense_defaults` is typed by `Tdense`. `sloppy` defaults to false.
// The shape function only sets a scalar output.
REGISTER_OP("ParseExampleDataset")
    .Input("input_dataset: variant")
    .Input("num_parallel_calls: int64")
    .Input("dense_defaults: Tdense")
    .Output("handle: variant")
    .Attr("sparse_keys: list(string) >= 0")
    .Attr("dense_keys: list(string) >= 0")
    .Attr("sparse_types: list({float,int64,string}) >= 0")
    .Attr("Tdense: list({float,int64,string}) >= 0")
    .Attr("dense_shapes: list(shape) >= 0")
    .Attr("output_types: list(type) >= 1")
    // Output components are sorted by key (dense_keys and sparse_keys
    // combined).
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("sloppy: bool = false")
    .Attr("ragged_keys: list(string) >= 0 = []")
    .Attr("ragged_value_types: list({float,int64,string}) >= 0 = []")
    .Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
801
// V2 of ParseExampleDataset: identical inputs and key/type attrs, but the
// `sloppy` bool attr is replaced by a `deterministic` string attr.
REGISTER_OP("ParseExampleDatasetV2")
    .Input("input_dataset: variant")
    .Input("num_parallel_calls: int64")
    .Input("dense_defaults: Tdense")
    .Output("handle: variant")
    .Attr("sparse_keys: list(string) >= 0")
    .Attr("dense_keys: list(string) >= 0")
    .Attr("sparse_types: list({float,int64,string}) >= 0")
    .Attr("Tdense: list({float,int64,string}) >= 0")
    .Attr("dense_shapes: list(shape) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")  // Output components will be
                                              // sorted by key (dense_keys and
                                              // sparse_keys combined) here.
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("ragged_keys: list(string) >= 0 = []")
    .Attr("ragged_value_types: list({float,int64,string}) >= 0 = []")
    .Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
824
// Experimental-prefixed counterpart of ParseExampleDataset without the
// `ragged_*` attrs. NOTE(review): presumably retained so graphs using the
// old op name still load — confirm.
REGISTER_OP("ExperimentalParseExampleDataset")
    .Input("input_dataset: variant")
    .Input("num_parallel_calls: int64")
    .Input("dense_defaults: Tdense")
    .Output("handle: variant")
    .Attr("sparse_keys: list(string) >= 0")
    .Attr("dense_keys: list(string) >= 0")
    .Attr("sparse_types: list({float,int64,string}) >= 0")
    .Attr("Tdense: list({float,int64,string}) >= 0")
    .Attr("dense_shapes: list(shape) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")  // Output components will be
                                              // sorted by key (dense_keys and
                                              // sparse_keys combined) here.
    .Attr("sloppy: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
843
// Wraps `input_dataset` given an int64 `num_threads` input. NOTE(review):
// presumably runs the input on a dedicated thread pool of that size — confirm
// against the kernel. `num_threads` is not rank-checked by the shape function.
REGISTER_OP("PrivateThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("num_threads: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
853
// Experimental-prefixed counterpart of PrivateThreadPoolDataset with the same
// inputs, attrs, and shape function. NOTE(review): presumably kept for
// compatibility with graphs that reference the old name — confirm.
REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("num_threads: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
863
864REGISTER_OP("ExperimentalRandomDataset")
865 .Input("seed: int64")
866 .Input("seed2: int64")
867 .Output("handle: variant")
868 .Attr("output_types: list(type) >= 1")
869 .Attr("output_shapes: list(shape) >= 1")
870 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
871 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
872 "output_types"))
873 .SetShapeFn([](shape_inference::InferenceContext* c) {
874 shape_inference::ShapeHandle unused;
875 // buffer_size, seed, and seed2 should be scalars.
876 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
877 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
878 return shape_inference::ScalarShape(c);
879 });
880
881REGISTER_OP("RandomDataset")
882 .Input("seed: int64")
883 .Input("seed2: int64")
884 .Output("handle: variant")
885 .Attr("output_types: list(type) >= 1")
886 .Attr("output_shapes: list(shape) >= 1")
887 .Attr("metadata: string = ''")
888 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
889 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
890 "output_types"))
891 .SetShapeFn([](shape_inference::InferenceContext* c) {
892 shape_inference::ShapeHandle unused;
893 // buffer_size, seed, and seed2 should be scalars.
894 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
895 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
896 return shape_inference::ScalarShape(c);
897 });
898
// Experimental-prefixed counterpart of RebatchDataset with the same inputs
// and attrs, but without RebatchDataset's forward type function.
REGISTER_OP("ExperimentalRebatchDataset")
    .Input("input_dataset: variant")
    .Input("num_replicas: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_fallback: bool = true")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
909
// Rebatches `input_dataset` given an int64 `num_replicas` input.
REGISTER_OP("RebatchDataset")
    .Input("input_dataset: variant")
    .Input("num_replicas: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_fallback: bool = true")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    // Output full type is derived from input 0's TFT_DATASET type, with each
    // element type mapped through full_type::BatchTensor.
    .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
                                              full_type::BatchTensor))
    .SetShapeFn(shape_inference::ScalarShape);
922
// V2 of RebatchDataset: takes `batch_sizes` and a `drop_remainder` bool
// input instead of `num_replicas`, and drops the `use_fallback` attr.
REGISTER_OP("RebatchDatasetV2")
    .Input("input_dataset: variant")
    .Input("batch_sizes: int64")
    .Input("drop_remainder: bool")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    // Same forward type function as RebatchDataset (batch each element type
    // of input 0's dataset type).
    .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
                                              full_type::BatchTensor))
    .SetShapeFn(shape_inference::ScalarShape);
935
936REGISTER_OP("SamplingDataset")
937 .Input("input_dataset: variant")
938 .Input("rate: float32")
939 .Input("seed: int64")
940 .Input("seed2: int64")
941 .Output("handle: variant")
942 .Attr("output_types: list(type) >= 1")
943 .Attr("output_shapes: list(shape) >= 1")
944 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
945 "output_types"))
946 .SetShapeFn([](shape_inference::InferenceContext* c) {
947 shape_inference::ShapeHandle unused;
948 // rate, seed, and seed2 should be scalars.
949 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
950 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
951 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
952 return shape_inference::ScalarShape(c);
953 });
954
// Dataset op applying `f` with an initial state (`initial_state`, typed by
// `Tstate`, at least one tensor) plus extra `other_arguments` (`Targuments`).
REGISTER_OP("ScanDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("preserve_cardinality: bool = false")
    .Attr("use_default_device: bool = true")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
971
// Experimental-prefixed counterpart of ScanDataset without the
// `use_default_device` and `metadata` attrs.
REGISTER_OP("ExperimentalScanDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("preserve_cardinality: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
986
// Dataset op that takes a `stats_aggregator` resource plus `tag` and
// `counter_prefix` strings. The string inputs are not rank-checked here.
REGISTER_OP("SetStatsAggregatorDataset")
    .Input("input_dataset: variant")
    .Input("stats_aggregator: resource")
    .Input("tag: string")
    .Input("counter_prefix: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
998
// Experimental-prefixed counterpart of SetStatsAggregatorDataset with the
// same inputs, attrs, and shape function.
REGISTER_OP("ExperimentalSetStatsAggregatorDataset")
    .Input("input_dataset: variant")
    .Input("stats_aggregator: resource")
    .Input("tag: string")
    .Input("counter_prefix: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1010
1011REGISTER_OP("SleepDataset")
1012 .Input("input_dataset: variant")
1013 .Input("sleep_microseconds: int64")
1014 .Output("handle: variant")
1015 .Attr("output_types: list(type) >= 1")
1016 .Attr("output_shapes: list(shape) >= 1")
1017 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1018 "output_types"))
1019 .SetShapeFn([](shape_inference::InferenceContext* c) {
1020 shape_inference::ShapeHandle unused;
1021 // Both inputs are scalar.
1022 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused));
1023 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused));
1024 return shape_inference::ScalarShape(c);
1025 });
1026
1027REGISTER_OP("ExperimentalSleepDataset")
1028 .Input("input_dataset: variant")
1029 .Input("sleep_microseconds: int64")
1030 .Output("handle: variant")
1031 .Attr("output_types: list(type) >= 1")
1032 .Attr("output_shapes: list(shape) >= 1")
1033 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1034 "output_types"))
1035 .SetShapeFn([](shape_inference::InferenceContext* c) {
1036 shape_inference::ShapeHandle unused;
1037 // Both inputs are scalar.
1038 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused));
1039 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused));
1040 return shape_inference::ScalarShape(c);
1041 });
1042
1043REGISTER_OP("SlidingWindowDataset")
1044 .Input("input_dataset: variant")
1045 .Input("window_size: int64")
1046 .Input("window_shift: int64")
1047 .Input("window_stride: int64")
1048 .Output("handle: variant")
1049 .Attr("drop_remainder: bool = true")
1050 .Attr("output_types: list(type) >= 1")
1051 .Attr("output_shapes: list(shape) >= 1")
1052 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1053 "output_types"))
1054 .SetShapeFn([](shape_inference::InferenceContext* c) {
1055 shape_inference::ShapeHandle unused;
1056 // window_size, window_shift, and window_stride should be scalars.
1057 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1058 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1059 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1060 return shape_inference::ScalarShape(c);
1061 });
1062
1063REGISTER_OP("ExperimentalSlidingWindowDataset")
1064 .Input("input_dataset: variant")
1065 .Input("window_size: int64")
1066 .Input("window_shift: int64")
1067 .Input("window_stride: int64")
1068 .Output("handle: variant")
1069 .Attr("output_types: list(type) >= 1")
1070 .Attr("output_shapes: list(shape) >= 1")
1071 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1072 "output_types"))
1073 .SetShapeFn([](shape_inference::InferenceContext* c) {
1074 shape_inference::ShapeHandle unused;
1075 // window_size, window_shift, and window_stride should be scalars.
1076 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1077 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1078 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1079 return shape_inference::ScalarShape(c);
1080 });
1081
1082REGISTER_OP("SnapshotDataset")
1083 .Input("input_dataset: variant")
1084 .Input("path: string")
1085 .Output("handle: variant")
1086 .Attr("output_types: list(type) >= 1")
1087 .Attr("output_shapes: list(shape) >= 1")
1088 .Attr("compression: string = ''")
1089 .Attr("reader_path_prefix: string = ''")
1090 .Attr("writer_path_prefix: string = ''")
1091 .Attr("shard_size_bytes: int = 10737418240") // 10 GiB default
1092 .Attr("pending_snapshot_expiry_seconds: int = 86400") // 1 day default
1093 .Attr("num_reader_threads: int = 1")
1094 .Attr("reader_buffer_size: int = 1")
1095 .Attr("num_writer_threads: int = 1")
1096 .Attr("writer_buffer_size: int = 1")
1097 .Attr("shuffle_on_read: bool = false")
1098 .Attr("seed: int = 0")
1099 .Attr("seed2: int = 0")
1100 .Attr("mode: string = 'auto'")
1101 .Attr("snapshot_name: string = ''")
1102 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1103 "output_types"))
1104 .SetShapeFn([](shape_inference::InferenceContext* c) {
1105 shape_inference::ShapeHandle unused;
1106 // snapshot_path should be a scalar.
1107 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1108 return shape_inference::ScalarShape(c);
1109 });
1110
1111REGISTER_OP("SnapshotDatasetV2")
1112 .Input("input_dataset: variant")
1113 .Input("path: string")
1114 .Input("reader_func_other_args: Treader_func_args")
1115 .Input("shard_func_other_args: Tshard_func_args")
1116 .Output("handle: variant")
1117 .Attr("output_types: list(type) >= 1")
1118 .Attr("output_shapes: list(shape) >= 1")
1119 .Attr("compression: string = ''")
1120 .Attr("reader_prefix: string = ''")
1121 .Attr("writer_prefix: string = ''")
1122 .Attr("hash_valid: bool = false")
1123 .Attr("hash: int = 0")
1124 .Attr("reader_func: func")
1125 .Attr("shard_func: func")
1126 .Attr("Treader_func_args: list(type) >= 0")
1127 .Attr("Tshard_func_args: list(type) >= 0")
1128 .Attr("metadata: string = ''")
1129 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1130 "output_types"))
1131 .SetShapeFn([](shape_inference::InferenceContext* c) {
1132 shape_inference::ShapeHandle unused;
1133 // `path` should be a scalar.
1134 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1135 return shape_inference::ScalarShape(c);
1136 });
1137
1138REGISTER_OP("SaveDataset")
1139 .Input("input_dataset: variant")
1140 .Input("path: string")
1141 .Input("shard_func_other_args: Tshard_func_args")
1142 .Attr("compression: string = ''")
1143 .Attr("shard_func: func")
1144 .Attr("use_shard_func: bool = true")
1145 .Attr("Tshard_func_args: list(type) >= 0")
1146 .SetIsStateful()
1147 .SetShapeFn([](shape_inference::InferenceContext* c) {
1148 shape_inference::ShapeHandle unused;
1149 // `path` should be a scalar.
1150 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1151 return OkStatus();
1152 });
1153
1154REGISTER_OP("SaveDatasetV2")
1155 .Input("input_dataset: variant")
1156 .Input("path: string")
1157 .Input("shard_func_other_args: Tshard_func_args")
1158 .Output("handle: variant")
1159 .Attr("compression: string = ''")
1160 .Attr("shard_func: func")
1161 .Attr("use_shard_func: bool = true")
1162 .Attr("Tshard_func_args: list(type) >= 0")
1163 .Attr("output_types: list(type) >= 1")
1164 .Attr("output_shapes: list(shape) >= 1")
1165 .SetIsStateful()
1166 .SetShapeFn([](shape_inference::InferenceContext* c) {
1167 shape_inference::ShapeHandle unused;
1168 // `path` should be a scalar.
1169 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1170 return shape_inference::ScalarShape(c);
1171 });
1172
1173REGISTER_OP("LoadDataset")
1174 .Input("path: string")
1175 .Input("reader_func_other_args: Treader_func_args")
1176 .Output("handle: variant")
1177 .Attr("output_types: list(type) >= 1")
1178 .Attr("output_shapes: list(shape) >= 1")
1179 .Attr("compression: string = ''")
1180 .Attr("reader_func: func")
1181 .Attr("Treader_func_args: list(type) >= 0")
1182 .SetIsStateful()
1183 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1184 "output_types"))
1185 .SetShapeFn([](shape_inference::InferenceContext* c) {
1186 shape_inference::ShapeHandle unused;
1187 // `path` should be a scalar.
1188 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1189 return shape_inference::ScalarShape(c);
1190 });
1191
1192REGISTER_OP("SnapshotDatasetReader")
1193 .Input("shard_dir: string")
1194 .Input("start_index: int64")
1195 .Output("handle: variant")
1196 .Attr("output_types: list(type) >= 1")
1197 .Attr("output_shapes: list(shape) >= 1")
1198 .Attr("compression: string = ''")
1199 .Attr("version: int")
1200 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1201 "output_types"))
1202 .SetShapeFn([](shape_inference::InferenceContext* c) {
1203 shape_inference::ShapeHandle unused;
1204 // `shard_dir` should be a scalar.
1205 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1206 // `start_index` should be a scalar.
1207 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1208 return shape_inference::ScalarShape(c);
1209 });
1210
// Takes N >= 1 variant `inputs` and produces a single dataset `handle`.
// NOTE(review): presumably combines per-shard SnapshotDatasetReader outputs
// — confirm against the kernel.
REGISTER_OP("SnapshotNestedDatasetReader")
    .Input("inputs: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1220
1221REGISTER_OP("SqlDataset")
1222 .Input("driver_name: string")
1223 .Input("data_source_name: string")
1224 .Input("query: string")
1225 .Output("handle: variant")
1226 .Attr("output_types: list(type) >= 1")
1227 .Attr("output_shapes: list(shape) >= 1")
1228 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
1229 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1230 "output_types"))
1231 .SetShapeFn([](shape_inference::InferenceContext* c) {
1232 shape_inference::ShapeHandle unused;
1233 // driver_name, data_source_name, and query should be scalars.
1234 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1235 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1236 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1237 return shape_inference::ScalarShape(c);
1238 });
1239
1240REGISTER_OP("ExperimentalSqlDataset")
1241 .Input("driver_name: string")
1242 .Input("data_source_name: string")
1243 .Input("query: string")
1244 .Output("handle: variant")
1245 .Attr("output_types: list(type) >= 1")
1246 .Attr("output_shapes: list(shape) >= 1")
1247 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
1248 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
1249 "output_types"))
1250 .SetShapeFn([](shape_inference::InferenceContext* c) {
1251 shape_inference::ShapeHandle unused;
1252 // driver_name, data_source_name, and query should be scalars.
1253 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1254 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1255 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1256 return shape_inference::ScalarShape(c);
1257 });
1258
// Produces a scalar `resource` handle; `container`/`shared_name` follow the
// usual resource-op attr pair (both default to empty).
REGISTER_OP("StatsAggregatorHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");
1264
// Experimental-prefixed counterpart of StatsAggregatorHandle with the same
// output and attrs.
REGISTER_OP("ExperimentalStatsAggregatorHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");
1270
// V2 stats-aggregator handle op; its OpDef is identical to
// StatsAggregatorHandle. NOTE(review): presumably backed by a different
// kernel implementation — confirm.
REGISTER_OP("StatsAggregatorHandleV2")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");
1276
// Takes a `stats_aggregator` resource and a `summary` resource; produces no
// outputs.
REGISTER_OP("StatsAggregatorSetSummaryWriter")
    .Input("stats_aggregator: resource")
    .Input("summary: resource")
    .SetShapeFn(shape_inference::NoOutputs);
1281
// Produces a scalar string `summary` from a resource input.
// NOTE(review): the input is named `iterator` although it appears to be a
// stats-aggregator resource; the name is part of the OpDef, so keep it.
REGISTER_OP("StatsAggregatorSummary")
    .Input("iterator: resource")
    .Output("summary: string")
    .SetShapeFn(shape_inference::ScalarShape);
1286
// Experimental-prefixed counterpart of StatsAggregatorSummary with the same
// input, output, and shape function.
REGISTER_OP("ExperimentalStatsAggregatorSummary")
    .Input("iterator: resource")
    .Output("summary: string")
    .SetShapeFn(shape_inference::ScalarShape);
1291
// Dataset op parameterized by a `predicate` function, with extra
// `other_arguments` (typed by `Targuments`).
REGISTER_OP("TakeWhileDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("predicate: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1304
// Experimental-prefixed counterpart of TakeWhileDataset without the
// `metadata` attr.
REGISTER_OP("ExperimentalTakeWhileDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("predicate: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1316
// Dataset op that takes a `thread_pool` resource (e.g. one created by
// ThreadPoolHandle — confirm) alongside the upstream dataset.
REGISTER_OP("ThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("thread_pool: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1326
// Experimental-prefixed counterpart of ThreadPoolDataset with the same
// inputs, attrs, and shape function.
REGISTER_OP("ExperimentalThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("thread_pool: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1336
// Produces a scalar thread-pool `resource` handle. `num_threads` and
// `display_name` are required attrs; `max_intra_op_parallelism` defaults to 1.
REGISTER_OP("ThreadPoolHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("num_threads: int")
    .Attr("max_intra_op_parallelism: int = 1")
    .Attr("display_name: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");
1345
// Experimental-prefixed counterpart of ThreadPoolHandle with the same output
// and attrs.
REGISTER_OP("ExperimentalThreadPoolHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("num_threads: int")
    .Attr("max_intra_op_parallelism: int = 1")
    .Attr("display_name: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");
1354
// Single-input dataset transformation; only attrs configure it.
REGISTER_OP("UnbatchDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1364
// Experimental-prefixed counterpart of UnbatchDataset without the `metadata`
// attr.
REGISTER_OP("ExperimentalUnbatchDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1373
// Single-input dataset transformation; only attrs configure it.
REGISTER_OP("UniqueDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1383
// Experimental-prefixed counterpart of UniqueDataset without the `metadata`
// attr.
REGISTER_OP("ExperimentalUniqueDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1392
1393REGISTER_OP("DummyIterationCounter")
1394 .Output("handle: resource")
1395 .SetShapeFn([](shape_inference::InferenceContext* c) {
1396 c->set_output(0, c->Scalar());
1397 return OkStatus();
1398 });
1399
// Original data-service dataset op (no `consumer_index`/`num_consumers`).
// Stateful; takes an int64 `dataset_id`, service connection strings, and an
// `iteration_counter` resource. Later versions below extend this signature.
REGISTER_OP("DataServiceDataset")
    .Input("dataset_id: int64")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1419
// Adds `consumer_index` and `num_consumers` arguments to support round-robin
// reads. They are inserted between `job_name` and `max_outstanding_requests`;
// all attrs are unchanged from DataServiceDataset.
REGISTER_OP("DataServiceDatasetV2")
    .Input("dataset_id: int64")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("consumer_index: int64")
    .Input("num_consumers: int64")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1443
// Adds `uncompress` and `uncompress_fn` attributes to support uncompression.
// Inputs are unchanged from DataServiceDatasetV2. `uncompress_fn` has no
// default, so it must always be supplied.
REGISTER_OP("DataServiceDatasetV3")
    .Input("dataset_id: int64")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("consumer_index: int64")
    .Input("num_consumers: int64")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    .Attr("uncompress: bool = false")
    .Attr("uncompress_fn: func")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1468
// Changes `dataset_id` from int64 to string. Everything else matches
// DataServiceDatasetV3. The ID is produced as a string by RegisterDatasetV2.
REGISTER_OP("DataServiceDatasetV4")
    .Input("dataset_id: string")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("consumer_index: int64")
    .Input("num_consumers: int64")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    .Attr("uncompress: bool = false")
    .Attr("uncompress_fn: func")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1493
// Takes a `dataset` variant plus service `address`/`protocol` strings and
// returns a scalar int64 `dataset_id`. `external_state_policy` is a required
// int attr. Not marked stateful in this registration.
REGISTER_OP("RegisterDataset")
    .Input("dataset: variant")
    .Input("address: string")
    .Input("protocol: string")
    .Output("dataset_id: int64")
    .Attr("external_state_policy: int")
    .Attr("element_spec: string = ''")
    .Attr("metadata: string = ''")
    .SetShapeFn(shape_inference::ScalarShape);
1503
// Changes `dataset_id` from int64 to string. Also adds an optional
// `requested_dataset_id` attr (defaults to empty).
REGISTER_OP("RegisterDatasetV2")
    .Input("dataset: variant")
    .Input("address: string")
    .Input("protocol: string")
    .Output("dataset_id: string")
    .Attr("external_state_policy: int")
    .Attr("element_spec: string = ''")
    .Attr("requested_dataset_id: string = ''")
    .Attr("metadata: string = ''")
    .SetShapeFn(shape_inference::ScalarShape);
1515
1516REGISTER_OP("InitializeTableFromDataset")
1517 .Input("table_handle: resource")
1518 .Input("dataset: variant")
1519 .SetShapeFn([](shape_inference::InferenceContext* c) {
1520 shape_inference::ShapeHandle handle;
1521 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
1522 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
1523 return OkStatus();
1524 });
1525
// - `output_types` lists the types of the tensors in a single dataset element.
// - `output_shapes` lists the shapes of the tensors in a single dataset
//   element.
// - `output_types` and `output_shapes` have the same length: the number of
//   tensors in a single dataset element, a.k.a. the number of components.
// - `Tinput_types` lists the types of the tensors for all dataset elements.
//   It is equivalent to `output_types` repeated once per dataset element
//   (N elements in total).
REGISTER_OP("ListDataset")
    .Input("tensors: Tinput_types")
    .Output("handle: variant")
    .Attr("Tinput_types: list(type) >= 1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetDoNotOptimize()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1544
1545} // namespace tensorflow
1546