1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/framework/common_shape_fns.h"
16#include "tensorflow/core/framework/full_type.pb.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/op_def_builder.h"
19#include "tensorflow/core/framework/shape_inference.h"
20
21namespace tensorflow {
22
23// --------------------------------------------------------------------------
24
// The ops in this section can be composed to define an input
// pipeline. Each op produces a DT_VARIANT tensor that represents
// a DAG of "dataset" objects. A "dataset" object can be converted
// to a stateful "iterator" by passing the "dataset" to the
// "MakeIterator" op.
//
// TODO(b/123753214): DT_VARIANT tensors that represent "dataset" objects are
// not presently serializable. To avoid issues with graph optimizations, such
// as constant folding, CSE, or DCE, ensure that any "source dataset" ops
// (i.e. ops that output a dataset and do not take one as input) are
// marked as "do not optimize".
36
// TODO(mrry): Validate that `components` have shapes compatible with
// `output_shapes`.
//
// Source dataset op built from the `components` tensors (a list typed by
// `Toutput_types`). It takes no dataset input, so it is marked
// do-not-optimize. The only shape constraint is that the output `handle`
// is a scalar variant.
REGISTER_OP("TensorDataset")
    .Input("components: Toutput_types")
    .Output("handle: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    // Full type: a TFT_DATASET built from the `Toutput_types` list.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "Toutput_types"))
    .SetShapeFn(shape_inference::ScalarShape);
49
// TODO(mrry): Validate that the dim-0 slices of `components` have shapes
// compatible with `output_shapes`.
//
// Source dataset op built from the `components` tensors (typed by
// `Toutput_types`). It takes no dataset input, so it is marked
// do-not-optimize. The output `handle` is a scalar variant.
REGISTER_OP("TensorSliceDataset")
    .Input("components: Toutput_types")
    .Output("handle: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("is_files: bool = false")
    .Attr("metadata: string = ''")
    .Attr("replicate_on_split: bool = false")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "Toutput_types"))
    // Forward type inference uses MultiaryUnstack with UnstackTensor;
    // presumably this strips the sliced leading dimension from each
    // component's type -- TODO confirm against full_type_util.
    .SetForwardTypeFn(full_type::MultiaryUnstack(TFT_DATASET,
                                                 full_type::UnstackTensor))
    .SetShapeFn(shape_inference::ScalarShape);
66
// Source dataset op whose inputs are presumably the `indices`, `values` and
// `dense_shape` parts of a sparse tensor (inferred from the names). It is
// marked do-not-optimize because it takes no dataset input.
// NOTE(review): the shape function only fixes the output to a scalar; the
// ranks of the three inputs are not validated at graph-construction time.
REGISTER_OP("SparseTensorSliceDataset")
    .Input("indices: int64")
    .Input("values: Tvalues")
    .Input("dense_shape: int64")
    .Output("handle: variant")
    .Attr("Tvalues: type")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, "Tvalues"))
    .SetShapeFn(shape_inference::ScalarShape);
76
// Source dataset op driven by three function attrs (`init_func`,
// `next_func`, `finalize_func`). Each function has its own list of extra
// inputs, typed by the matching `T*_func_args` attr. It takes no dataset
// input and is therefore marked do-not-optimize. The output `handle` is a
// scalar variant.
REGISTER_OP("GeneratorDataset")
    .Input("init_func_other_args: Tinit_func_args")
    .Input("next_func_other_args: Tnext_func_args")
    .Input("finalize_func_other_args: Tfinalize_func_args")
    .Output("handle: variant")
    .Attr("init_func: func")
    .Attr("next_func: func")
    .Attr("finalize_func: func")
    .Attr("Tinit_func_args: list(type) >= 0")
    .Attr("Tnext_func_args: list(type) >= 0")
    .Attr("Tfinalize_func_args: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
95
// Dataset op over `N >= 1` input dataset handles (`input_datasets`). The
// output `handle` is a scalar variant; the input ranks are not checked here.
REGISTER_OP("ZipDataset")
    .Input("input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
106
// Dataset op over two dataset handles, `input_dataset` and `another_dataset`.
// The output `handle` is a scalar variant.
REGISTER_OP("ConcatenateDataset")
    .Input("input_dataset: variant")
    .Input("another_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
117
118REGISTER_OP("RepeatDataset")
119 .Input("input_dataset: variant")
120 .Input("count: int64")
121 .Output("handle: variant")
122 .Attr("output_types: list(type) >= 1")
123 .Attr("output_shapes: list(shape) >= 1")
124 .Attr("metadata: string = ''")
125 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
126 "output_types"))
127 .SetShapeFn([](shape_inference::InferenceContext* c) {
128 shape_inference::ShapeHandle count_shape;
129 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
130 return shape_inference::ScalarShape(c);
131 });
132
133REGISTER_OP("TakeDataset")
134 .Input("input_dataset: variant")
135 .Input("count: int64")
136 .Output("handle: variant")
137 .Attr("output_types: list(type) >= 1")
138 .Attr("output_shapes: list(shape) >= 1")
139 .Attr("metadata: string = ''")
140 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
141 "output_types"))
142 .SetShapeFn([](shape_inference::InferenceContext* c) {
143 shape_inference::ShapeHandle count_shape;
144 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
145 return shape_inference::ScalarShape(c);
146 });
147
148REGISTER_OP("SkipDataset")
149 .Input("input_dataset: variant")
150 .Input("count: int64")
151 .Output("handle: variant")
152 .Attr("output_types: list(type) >= 1")
153 .Attr("output_shapes: list(shape) >= 1")
154 .Attr("metadata: string = ''")
155 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
156 "output_types"))
157 .SetShapeFn([](shape_inference::InferenceContext* c) {
158 shape_inference::ShapeHandle count_shape;
159 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
160 return shape_inference::ScalarShape(c);
161 });
162
// Dataset op parameterized by function attr `f`. `other_arguments` (typed by
// `Targuments`) are presumably extra captured inputs passed to `f`. The
// output `handle` is a scalar variant.
REGISTER_OP("MapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("preserve_cardinality: bool = false")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
177
// Parallel variant of MapDataset. Compared with ParallelMapDatasetV2,
// `num_parallel_calls` is int32 (not int64), and the bool `sloppy` attr
// stands in for the string `deterministic` attr. The shape function does not
// check the rank of `num_parallel_calls`.
REGISTER_OP("ParallelMapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("num_parallel_calls: int32")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("sloppy: bool = false")
    .Attr("preserve_cardinality: bool = false")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
194
// Successor of ParallelMapDataset: `num_parallel_calls` is int64 and the
// string `deterministic` attr replaces the bool `sloppy` attr. The shape
// function does not check the rank of `num_parallel_calls`.
REGISTER_OP("ParallelMapDatasetV2")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("preserve_cardinality: bool = false")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
212
213REGISTER_OP("PrefetchDataset")
214 .Input("input_dataset: variant")
215 .Input("buffer_size: int64")
216 .Output("handle: variant")
217 .Attr("output_types: list(type) >= 1")
218 .Attr("output_shapes: list(shape) >= 1")
219 .Attr("slack_period: int = 0")
220 .Attr("legacy_autotune: bool = true")
221 .Attr("buffer_size_min: int = 0")
222 .Attr("metadata: string = ''")
223 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
224 "output_types"))
225 .SetShapeFn([](shape_inference::InferenceContext* c) {
226 shape_inference::ShapeHandle unused;
227 // buffer_size should be a scalar.
228 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
229 return shape_inference::ScalarShape(c);
230 });
231
// Dataset op parameterized by function attr `f`, with extra
// `other_arguments` typed by `Targuments`. The output `handle` is a scalar
// variant.
REGISTER_OP("FlatMapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
244
// Dataset op parameterized by function attr `f`, plus int64 `cycle_length`
// and `block_length` inputs. The shape function only fixes the output to a
// scalar; the ranks of `cycle_length` and `block_length` are not checked.
REGISTER_OP("InterleaveDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
259
// InterleaveDataset inputs plus an int64 `num_parallel_calls` input.
// Ordering is controlled by the bool `sloppy` attr, which V3 replaces with
// `deterministic`. No input ranks are checked here.
REGISTER_OP("ParallelInterleaveDatasetV2")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("sloppy: bool = false")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
276
// Same inputs as ParallelInterleaveDatasetV2, but the string `deterministic`
// attr replaces the bool `sloppy` attr. No input ranks are checked here.
REGISTER_OP("ParallelInterleaveDatasetV3")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
294
// Like V3, but adds buffer_output_elements and prefetch_input_elements.
// Both new int64 inputs sit between `block_length` and `num_parallel_calls`.
// No input ranks are checked here.
REGISTER_OP("ParallelInterleaveDatasetV4")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
315
// Dataset op parameterized by function attr `predicate`, with extra
// `other_arguments` typed by `Targuments`. The output `handle` is a scalar
// variant.
REGISTER_OP("FilterDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("predicate: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
328
// FilterDataset plus an int64 `num_parallel_calls` input and a
// `deterministic` attr. The rank of `num_parallel_calls` is not checked here.
REGISTER_OP("ParallelFilterDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("predicate: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
344
// This op is no longer supported.
// It is kept registered, presumably so existing graphs still resolve the op
// name. Unlike most ops here, its output is named `output` (not `handle`)
// and it has no `metadata` attr.
REGISTER_OP("FilterByLastComponentDataset")
    .Input("input_dataset: variant")
    .Output("output: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
354
355REGISTER_OP("WindowDataset")
356 .Input("input_dataset: variant")
357 .Input("size: int64")
358 .Input("shift: int64")
359 .Input("stride: int64")
360 .Input("drop_remainder: bool")
361 .Output("handle: variant")
362 .Attr("output_types: list(type) >= 1")
363 .Attr("output_shapes: list(shape) >= 1")
364 .Attr("metadata: string = ''")
365 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
366 "output_types"))
367 .SetShapeFn([](shape_inference::InferenceContext* c) {
368 shape_inference::ShapeHandle unused;
369 // size, shift, stride, and drop_remainder should be scalars.
370 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
371 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
372 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
373 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
374 return shape_inference::ScalarShape(c);
375 });
376
// Produces a dataset handle from the `inputs` tensors (typed by `Tinputs`).
// NOTE(review): this op takes no dataset input but, unlike the other such
// ops in this file, is not marked SetDoNotOptimize(). Confirm whether the
// TODO(b/123753214) guidance applies here.
REGISTER_OP("WindowOp")
    .Input("inputs: Tinputs")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("Tinputs: list(type) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
386
387REGISTER_OP("BatchDataset")
388 .Input("input_dataset: variant")
389 .Input("batch_size: int64")
390 .Output("handle: variant")
391 .Attr("output_types: list(type) >= 1")
392 .Attr("output_shapes: list(shape) >= 1")
393 .Attr("metadata: string = ''")
394 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
395 "output_types"))
396 .SetShapeFn([](shape_inference::InferenceContext* c) {
397 shape_inference::ShapeHandle unused;
398 // batch_size should be a scalar.
399 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
400 return shape_inference::ScalarShape(c);
401 });
402
403REGISTER_OP("BatchDatasetV2")
404 .Input("input_dataset: variant")
405 .Input("batch_size: int64")
406 .Input("drop_remainder: bool")
407 .Output("handle: variant")
408 .Attr("parallel_copy: bool = false")
409 .Attr("output_types: list(type) >= 1")
410 .Attr("output_shapes: list(shape) >= 1")
411 .Attr("metadata: string = ''")
412 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
413 "output_types"))
414 .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
415 full_type::BatchTensor))
416 .SetShapeFn([](shape_inference::InferenceContext* c) {
417 shape_inference::ShapeHandle unused;
418 // batch_size should be a scalar.
419 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
420 // drop_remainder should be a scalar.
421 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
422 return shape_inference::ScalarShape(c);
423 });
424
425REGISTER_OP("ParallelBatchDataset")
426 .Input("input_dataset: variant")
427 .Input("batch_size: int64")
428 .Input("num_parallel_calls: int64")
429 .Input("drop_remainder: bool")
430 .Output("handle: variant")
431 .Attr("parallel_copy: bool = false")
432 .Attr("output_types: list(type) >= 1")
433 .Attr("output_shapes: list(shape) >= 1")
434 // "true", "false", or "default".
435 .Attr("deterministic: string = 'default'")
436 .Attr("metadata: string = ''")
437 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
438 "output_types"))
439 .SetShapeFn([](shape_inference::InferenceContext* c) {
440 shape_inference::ShapeHandle unused;
441 // batch_size should be a scalar.
442 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
443 // num_parallel_calls should be a scalar.
444 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
445 // drop_remainder should be a scalar.
446 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
447 return shape_inference::ScalarShape(c);
448 });
449
450REGISTER_OP("ShardDataset")
451 .Input("input_dataset: variant")
452 .Input("num_shards: int64")
453 .Input("index: int64")
454 .Output("handle: variant")
455 .Attr("require_non_empty: bool = false")
456 .Attr("output_types: list(type) >= 1")
457 .Attr("output_shapes: list(shape) >= 1")
458 .Attr("metadata: string = ''")
459 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
460 "output_types"))
461 .SetShapeFn([](shape_inference::InferenceContext* c) {
462 shape_inference::ShapeHandle unused;
463 // num_shards should be a scalar.
464 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
465 // index should be a scalar.
466 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
467 return shape_inference::ScalarShape(c);
468 });
469
470// TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of
471// `output_types` and `output_shapes` are `N` the `output_shapes` are (as far as
472// possible to tell statically) compatible with `padded_shapes`, and that
473// `padding_values` are all scalars.
474REGISTER_OP("PaddedBatchDataset")
475 .Input("input_dataset: variant")
476 .Input("batch_size: int64")
477 .Input("padded_shapes: N * int64")
478 .Input("padding_values: Toutput_types")
479 .Output("handle: variant")
480 .Attr("Toutput_types: list(type) >= 1")
481 .Attr("output_shapes: list(shape) >= 1")
482 .Attr("N: int >= 1")
483 .Attr("metadata: string = ''")
484 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
485 "Toutput_types"))
486 .SetShapeFn([](shape_inference::InferenceContext* c) {
487 shape_inference::ShapeHandle unused;
488 // batch_size should be a scalar.
489 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
490 return shape_inference::ScalarShape(c);
491 });
492
493REGISTER_OP("PaddedBatchDatasetV2")
494 .Input("input_dataset: variant")
495 .Input("batch_size: int64")
496 .Input("padded_shapes: N * int64")
497 .Input("padding_values: Toutput_types")
498 .Input("drop_remainder: bool")
499 .Output("handle: variant")
500 .Attr("parallel_copy: bool = false")
501 .Attr("Toutput_types: list(type) >= 1")
502 .Attr("output_shapes: list(shape) >= 1")
503 .Attr("N: int >= 1")
504 .Attr("metadata: string = ''")
505 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
506 "Toutput_types"))
507 .SetShapeFn([](shape_inference::InferenceContext* c) {
508 shape_inference::ShapeHandle unused;
509 // batch_size should be a scalar.
510 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
511 // drop_remainder should be a scalar.
512 TF_RETURN_IF_ERROR(
513 c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
514 return shape_inference::ScalarShape(c);
515 });
516
517REGISTER_OP("RangeDataset")
518 .Input("start: int64")
519 .Input("stop: int64")
520 .Input("step: int64")
521 .Output("handle: variant")
522 .Attr("output_types: list(type) >= 1")
523 .Attr("output_shapes: list(shape) >= 1")
524 .Attr("metadata: string = ''")
525 .Attr("replicate_on_split: bool = false")
526 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
527 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
528 "output_types"))
529 .SetShapeFn([](shape_inference::InferenceContext* c) {
530 shape_inference::ShapeHandle unused;
531 // start, stop, and step should be scalars.
532 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
533 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
534 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
535 return shape_inference::ScalarShape(c);
536 });
537
// Dataset op taking `input_dataset` and a string `rewrite_name`. The rank of
// `rewrite_name` is not checked here; the output `handle` is a scalar.
REGISTER_OP("RewriteDataset")
    .Input("input_dataset: variant")
    .Input("rewrite_name: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
547
548REGISTER_OP("AnonymousSeedGenerator")
549 .Input("seed: int64")
550 .Input("seed2: int64")
551 .Input("reshuffle: bool")
552 .Output("handle: resource")
553 .Output("deleter: variant")
554 .SetShapeFn([](shape_inference::InferenceContext* c) {
555 c->set_output(0, c->Scalar());
556 c->set_output(1, c->Scalar());
557 return OkStatus();
558 });
559
// Returns a scalar int64 `cardinality` for `input_dataset`.
REGISTER_OP("DatasetCardinality")
    .Input("input_dataset: variant")
    .Output("cardinality: int64")
    .SetShapeFn(shape_inference::ScalarShape);
564
// Consumes the `handle`/`deleter` pair produced by AnonymousSeedGenerator.
// It has no outputs.
REGISTER_OP("DeleteSeedGenerator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);
569
570// Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator.
571REGISTER_OP("AnonymousRandomSeedGenerator")
572 .Input("seed: int64")
573 .Input("seed2: int64")
574 .Output("handle: resource")
575 .Output("deleter: variant")
576 .SetShapeFn([](shape_inference::InferenceContext* c) {
577 c->set_output(0, c->Scalar());
578 c->set_output(1, c->Scalar());
579 return OkStatus();
580 });
581
// Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator.
// Consumes a `handle`/`deleter` pair; it has no outputs.
REGISTER_OP("DeleteRandomSeedGenerator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);
587
588REGISTER_OP("DummySeedGenerator")
589 .Output("handle: resource")
590 .SetShapeFn([](shape_inference::InferenceContext* c) {
591 c->set_output(0, c->Scalar());
592 return OkStatus();
593 });
594
595REGISTER_OP("ShuffleDataset")
596 .Input("input_dataset: variant")
597 .Input("buffer_size: int64")
598 .Input("seed: int64")
599 .Input("seed2: int64")
600 .Output("handle: variant")
601 .Attr("reshuffle_each_iteration: bool = true")
602 .Attr("output_types: list(type) >= 1")
603 .Attr("output_shapes: list(shape) >= 1")
604 .Attr("metadata: string = ''")
605 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
606 "output_types"))
607 .SetShapeFn([](shape_inference::InferenceContext* c) {
608 shape_inference::ShapeHandle unused;
609 // buffer_size, seed, and seed2 should be scalars.
610 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
611 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
612 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
613 return shape_inference::ScalarShape(c);
614 });
615
616REGISTER_OP("ShuffleDatasetV2")
617 .Input("input_dataset: variant")
618 .Input("buffer_size: int64")
619 .Input("seed_generator: resource")
620 .Output("handle: variant")
621 .Attr("output_types: list(type) >= 1")
622 .Attr("output_shapes: list(shape) >= 1")
623 .Attr("metadata: string = ''")
624 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
625 "output_types"))
626 .SetShapeFn([](shape_inference::InferenceContext* c) {
627 shape_inference::ShapeHandle unused;
628 // buffer_size and seed_generator should be scalars.
629 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
630 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
631 return shape_inference::ScalarShape(c);
632 });
633
634REGISTER_OP("ShuffleDatasetV3")
635 .Input("input_dataset: variant")
636 .Input("buffer_size: int64")
637 .Input("seed: int64")
638 .Input("seed2: int64")
639 .Input("seed_generator: resource")
640 .Output("handle: variant")
641 .Attr("reshuffle_each_iteration: bool = true")
642 .Attr("output_types: list(type) >= 1")
643 .Attr("output_shapes: list(shape) >= 1")
644 .Attr("metadata: string = ''")
645 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
646 "output_types"))
647 .SetShapeFn([](shape_inference::InferenceContext* c) {
648 shape_inference::ShapeHandle unused;
649 // buffer_size, seed, seed2, and seed_generator should be scalars.
650 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
651 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
652 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
653 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
654 return shape_inference::ScalarShape(c);
655 });
656
657REGISTER_OP("ShuffleAndRepeatDataset")
658 .Input("input_dataset: variant")
659 .Input("buffer_size: int64")
660 .Input("seed: int64")
661 .Input("seed2: int64")
662 .Input("count: int64")
663 .Output("handle: variant")
664 .Attr("output_types: list(type) >= 1")
665 .Attr("output_shapes: list(shape) >= 1")
666 .Attr("reshuffle_each_iteration: bool = true")
667 .Attr("metadata: string = ''")
668 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
669 "output_types"))
670 .SetShapeFn([](shape_inference::InferenceContext* c) {
671 shape_inference::ShapeHandle unused;
672 // buffer_size, seed, seed2, and count should be scalars.
673 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
674 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
675 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
676 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
677 return shape_inference::ScalarShape(c);
678 });
679
680REGISTER_OP("ShuffleAndRepeatDatasetV2")
681 .Input("input_dataset: variant")
682 .Input("buffer_size: int64")
683 .Input("seed: int64")
684 .Input("seed2: int64")
685 .Input("count: int64")
686 .Input("seed_generator: resource")
687 .Output("handle: variant")
688 .Attr("reshuffle_each_iteration: bool = true")
689 .Attr("output_types: list(type) >= 1")
690 .Attr("output_shapes: list(shape) >= 1")
691 .Attr("metadata: string = ''")
692 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
693 "output_types"))
694 .SetShapeFn([](shape_inference::InferenceContext* c) {
695 shape_inference::ShapeHandle unused;
696 // buffer_size, seed, seed2, count, and seed_generator should be scalars.
697 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
698 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
699 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
700 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
701 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
702 return shape_inference::ScalarShape(c);
703 });
704
705REGISTER_OP("AnonymousMemoryCache")
706 .Output("handle: resource")
707 .Output("deleter: variant")
708 .SetShapeFn([](shape_inference::InferenceContext* c) {
709 c->set_output(0, c->Scalar());
710 c->set_output(1, c->Scalar());
711 return OkStatus();
712 });
713
// Consumes the `handle`/`deleter` pair produced by AnonymousMemoryCache.
// It has no outputs.
REGISTER_OP("DeleteMemoryCache")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);
718
719REGISTER_OP("DummyMemoryCache")
720 .Output("handle: resource")
721 .SetShapeFn([](shape_inference::InferenceContext* c) {
722 c->set_output(0, c->Scalar());
723 return OkStatus();
724 });
725
726REGISTER_OP("CacheDataset")
727 .Input("input_dataset: variant")
728 .Input("filename: string")
729 .Output("handle: variant")
730 .Attr("output_types: list(type) >= 1")
731 .Attr("output_shapes: list(shape) >= 1")
732 .Attr("metadata: string = ''")
733 // TODO(mdan): Should these use type inference instead?
734 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
735 "output_types"))
736 .SetShapeFn([](shape_inference::InferenceContext* c) {
737 shape_inference::ShapeHandle unused;
738 // filename should be a scalar.
739 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
740 return shape_inference::ScalarShape(c);
741 });
742
743REGISTER_OP("CacheDatasetV2")
744 .Input("input_dataset: variant")
745 .Input("filename: string")
746 .Input("cache: resource")
747 .Output("handle: variant")
748 .Attr("output_types: list(type) >= 1")
749 .Attr("output_shapes: list(shape) >= 1")
750 .Attr("metadata: string = ''")
751 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
752 "output_types"))
753 .SetShapeFn([](shape_inference::InferenceContext* c) {
754 shape_inference::ShapeHandle unused;
755 // filename should be a scalar.
756 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
757 // cache should be a scalar.
758 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
759 return shape_inference::ScalarShape(c);
760 });
761
762REGISTER_OP("TextLineDataset")
763 .Input("filenames: string")
764 .Input("compression_type: string")
765 .Input("buffer_size: int64")
766 .Attr("metadata: string = ''")
767 .Output("handle: variant")
768 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
769 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
770 TFT_STRING))
771 .SetShapeFn([](shape_inference::InferenceContext* c) {
772 shape_inference::ShapeHandle unused;
773 // `filenames` must be a scalar or a vector.
774 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
775 // `compression_type` could only be a scalar.
776 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
777 // `buffer_size` could only be a scalar.
778 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
779 return shape_inference::ScalarShape(c);
780 });
781
782REGISTER_OP("FixedLengthRecordDataset")
783 .Input("filenames: string")
784 .Input("header_bytes: int64")
785 .Input("record_bytes: int64")
786 .Input("footer_bytes: int64")
787 .Input("buffer_size: int64")
788 .Attr("metadata: string = ''")
789 .Output("handle: variant")
790 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
791 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
792 TFT_STRING))
793 .SetShapeFn([](shape_inference::InferenceContext* c) {
794 shape_inference::ShapeHandle unused;
795 // `filenames` must be a scalar or a vector.
796 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
797 // header_bytes, record_bytes, footer_bytes, buffer_size should be
798 // scalars.
799 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
800 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
801 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
802 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
803 return shape_inference::ScalarShape(c);
804 });
805
806REGISTER_OP("FixedLengthRecordDatasetV2")
807 .Input("filenames: string")
808 .Input("header_bytes: int64")
809 .Input("record_bytes: int64")
810 .Input("footer_bytes: int64")
811 .Input("buffer_size: int64")
812 .Input("compression_type: string")
813 .Attr("metadata: string = ''")
814 .Output("handle: variant")
815 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
816 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
817 TFT_STRING))
818 .SetShapeFn([](shape_inference::InferenceContext* c) {
819 shape_inference::ShapeHandle unused;
820 // `filenames` must be a scalar or a vector.
821 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
822 // header_bytes, record_bytes, footer_bytes, buffer_size should be
823 // scalars.
824 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
825 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
826 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
827 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
828 return shape_inference::ScalarShape(c);
829 });
830
831REGISTER_OP("TFRecordDataset")
832 .Input("filenames: string")
833 .Input("compression_type: string")
834 .Input("buffer_size: int64")
835 .Attr("metadata: string = ''")
836 .Output("handle: variant")
837 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc.
838 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
839 TFT_STRING))
840 .SetShapeFn([](shape_inference::InferenceContext* c) {
841 shape_inference::ShapeHandle unused;
842 // `filenames` must be a scalar or a vector.
843 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
844 // `compression_type` could only be a scalar.
845 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
846 // `buffer_size` could only be a scalar.
847 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
848 return shape_inference::ScalarShape(c);
849 });
850
// `Iterator`: takes no inputs and emits a single `resource` output `handle`,
// whose shape is inferred by `shape_inference::ScalarShape`.
// `shared_name` and `container` are required attrs (no defaults).
// `output_types`/`output_shapes` each need at least one entry; presumably
// they describe the elements the iterator yields — confirm.
REGISTER_OP("Iterator")
    .Output("handle: resource")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
858
// `IteratorV2`: declared with exactly the same output, attrs and shape
// function as `Iterator`. It has a single `resource` output `handle`, whose
// shape comes from `shape_inference::ScalarShape`, and requires the
// `shared_name` and `container` attrs.
// NOTE(review): the behavioral difference from `Iterator` lives in the kernel,
// not here — confirm before relying on either.
REGISTER_OP("IteratorV2")
    .Output("handle: resource")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
866
// `AnonymousIterator`: like `Iterator`, but with no `shared_name`/`container`
// attrs. Emits one `resource` output `handle`, shaped by
// `shape_inference::ScalarShape`.
REGISTER_OP("AnonymousIterator")
    .Output("handle: resource")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
872
873REGISTER_OP("AnonymousIteratorV2")
874 .Output("handle: resource")
875 .Output("deleter: variant")
876 .Attr("output_types: list(type) >= 1")
877 .Attr("output_shapes: list(shape) >= 1")
878 .SetShapeFn([](shape_inference::InferenceContext* c) {
879 c->set_output(0, c->Scalar());
880 c->set_output(1, c->Scalar());
881 return OkStatus();
882 });
883
884REGISTER_OP("AnonymousIteratorV3")
885 .Output("handle: resource")
886 .Attr("output_types: list(type) >= 1")
887 .Attr("output_shapes: list(shape) >= 1")
888 .SetShapeFn([](shape_inference::InferenceContext* c) {
889 c->set_output(0, c->Scalar());
890 return OkStatus();
891 });
892
// `DeleteIterator`: consumes an iterator `resource` handle and a `variant`
// deleter, and produces no outputs.
// NOTE(review): presumably the pair comes from `AnonymousIteratorV2`, which
// emits matching `resource`/`variant` outputs — confirm.
REGISTER_OP("DeleteIterator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);
897
// `DeleteMultiDeviceIterator`: consumes a multi-device iterator `resource`,
// `N` (possibly zero) additional iterator `resource` handles, and a
// `variant` deleter. Produces no outputs.
REGISTER_OP("DeleteMultiDeviceIterator")
    .Input("multi_device_iterator: resource")
    .Input("iterators: N * resource")
    .Input("deleter: variant")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::NoOutputs);
904
// `MakeIterator`: binds a `dataset` (`variant`) to an `iterator` (`resource`);
// this is how a dataset becomes a stateful iterator (see the section comment
// at the top of this file). Produces no outputs.
REGISTER_OP("MakeIterator")
    .Input("dataset: variant")
    .Input("iterator: resource")
    .SetTypeConstructor(full_type::NoOutputs())
    // Presumably infers the full type of input 1 (`iterator`) by mapping the
    // TFT_DATASET type of input 0 to TFT_ITERATOR — confirm in full_type docs.
    .SetReverseTypeFn(1, full_type::MapCovariant(TFT_DATASET, TFT_ITERATOR, 0))
    .SetShapeFn(shape_inference::NoOutputs);
911
// `OneShotIterator`: takes no inputs and emits one `resource` handle, shaped
// by `shape_inference::ScalarShape`. The dataset is supplied through the
// `dataset_factory` function attr rather than as an input. The op is marked
// stateful. `container`/`shared_name` default to ''.
REGISTER_OP("OneShotIterator")
    .Output("handle: resource")
    .Attr("dataset_factory: func")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
921
// `IteratorGetNext`: reads from an iterator `resource` and emits one output
// tensor per entry of `output_types`. Output shapes are computed by
// `shape_inference::DatasetIteratorShape` (presumably from `output_shapes` —
// confirm).
REGISTER_OP("IteratorGetNext")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
928
// `IteratorGetNextSync`: same declared signature and shape function as
// `IteratorGetNext`. NOTE(review): presumably the kernel runs synchronously,
// as the name suggests — nothing in this registration establishes that.
REGISTER_OP("IteratorGetNextSync")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
935
// `DatasetToSingleElement`: takes a dataset `variant` and emits one output per
// entry of `output_types`. Shapes come from
// `shape_inference::DatasetIteratorShape`.
//
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
REGISTER_OP("DatasetToSingleElement")
    .Input("dataset: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::DatasetIteratorShape);
948
// `ReduceDataset`: takes an input dataset, an `initial_state` (one or more
// tensors typed by `Tstate`) and zero or more `other_arguments` (typed by
// `Targuments`). It applies the function attr `f` and emits one output per
// entry of `output_types`, shaped by `shape_inference::DatasetIteratorShape`.
// `use_inter_op_parallelism` defaults to true.
//
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
REGISTER_OP("ReduceDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("components: output_types")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("metadata: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::DatasetIteratorShape);
967
// `IteratorToStringHandle`: converts an iterator `resource` into a `string`
// handle, whose shape is inferred by `shape_inference::ScalarShape`.
REGISTER_OP("IteratorToStringHandle")
    .Input("resource_handle: resource")
    .Output("string_handle: string")
    .SetShapeFn(shape_inference::ScalarShape);
972
// `IteratorFromStringHandle`: converts a `string` handle back into an iterator
// `resource`, shaped by `shape_inference::ScalarShape`. Unlike most ops here,
// `output_types`/`output_shapes` are optional and default to empty lists.
REGISTER_OP("IteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);
979
// `IteratorFromStringHandleV2`: declared identically to
// `IteratorFromStringHandle` (string in, scalar-shaped `resource` out,
// optional type/shape attrs). Any behavioral difference lives in the kernel.
REGISTER_OP("IteratorFromStringHandleV2")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);
986
987REGISTER_OP("SerializeIterator")
988 .Input("resource_handle: resource")
989 .Attr("external_state_policy: int = 0")
990 .Output("serialized: variant")
991 .SetShapeFn([](shape_inference::InferenceContext* c) {
992 c->set_output(0, c->Vector(c->UnknownDim()));
993 return OkStatus();
994 });
995
// `DeserializeIterator`: consumes an iterator `resource` and a `variant`
// `serialized` tensor (the dtype `SerializeIterator` emits). Produces no
// outputs.
REGISTER_OP("DeserializeIterator")
    .Input("resource_handle: resource")
    .Input("serialized: variant")
    .SetShapeFn(shape_inference::NoOutputs);
1000
// `DatasetToGraph`: converts a dataset `variant` into a `string` `graph`
// output, shaped by `shape_inference::ScalarShape`. The attr name
// `stateful_whitelist` is part of the registered interface and cannot be
// renamed without breaking existing graphs.
REGISTER_OP("DatasetToGraph")
    .Input("input_dataset: variant")
    .Attr("stateful_whitelist: list(string) >= 0 = []")
    .Attr("allow_stateful: bool = false")
    .Attr("strip_device_assignment: bool = false")
    .Output("graph: string")
    .SetShapeFn(shape_inference::ScalarShape);
1008
// `DatasetToGraphV2`: like `DatasetToGraph`, but it replaces
// `stateful_whitelist`/`allow_stateful` with an int `external_state_policy`
// (default 0).
REGISTER_OP("DatasetToGraphV2")
    .Input("input_dataset: variant")
    .Attr("external_state_policy: int = 0")
    .Attr("strip_device_assignment: bool = false")
    .Output("graph: string")
    // Presumably the output's full type encodes input 0's type as a
    // TFT_STRING — confirm against full_type::Encode.
    .SetForwardTypeFn(full_type::Encode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);
1016
// `OptimizeDataset`: wraps `input_dataset` with the optimizations named in
// the `optimizations` string input. Emits a `variant` dataset handle, shaped
// by `shape_inference::ScalarShape`, whose full type is a TFT_DATASET
// container over `output_types`.
REGISTER_OP("OptimizeDataset")
    .Input("input_dataset: variant")
    .Input("optimizations: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("optimization_configs: list(string) = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1027
// `OptimizeDatasetV2`: like `OptimizeDataset`, but it takes three `string`
// inputs (enabled, disabled and default optimizations) instead of one list.
// The dataset handle output's shape comes from `shape_inference::ScalarShape`.
REGISTER_OP("OptimizeDatasetV2")
    .Input("input_dataset: variant")
    .Input("optimizations_enabled: string")
    .Input("optimizations_disabled: string")
    .Input("optimizations_default: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("optimization_configs: list(string) = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1040
1041REGISTER_OP("OptionalFromValue")
1042 .Input("components: Toutput_types")
1043 .Output("optional: variant")
1044 .Attr("Toutput_types: list(type) >= 1")
1045 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_OPTIONAL,
1046 "Toutput_types"))
1047 .SetShapeFn([](shape_inference::InferenceContext* c) {
1048 std::vector<DataType> dtypes;
1049 TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes));
1050 c->set_output(0, c->Scalar());
1051 std::vector<shape_inference::ShapeAndType> shapes_and_types;
1052 shapes_and_types.reserve(c->num_inputs());
1053 const FullTypeDef& ret_types = c->ret_types();
1054 for (int i = 0; i < c->num_inputs(); ++i) {
1055 // TODO(mdan): output_type(i) == optional is incorrect.
1056 // "Optional" is the type of the whole container, not of individual
1057 // elements.
1058 //
1059 // Why ret_types.args(0) and not args(i) --
1060 // For example if Toutput_types is (int32, float32), then
1061 // ret_types.args[0] (i.e. the 0th output) is
1062 // Optional[Record[Tensor[int32, s1], Tensor[float32, s2]]]
1063 // set_output_handle_shapes_and_types tracks the same thing, but in
1064 // a transposed way:
1065 // {ShapeAndType(in32, s1, Optional), ShapeAndType(in32, s2, Optional)}
1066 // That should be corrected in the future (see todo above).
1067 shapes_and_types.emplace_back(c->input(i), dtypes[i],
1068 ret_types.args(0));
1069 }
1070 c->set_output_handle_shapes_and_types(0, shapes_and_types);
1071 return OkStatus();
1072 });
1073
// `OptionalNone`: takes no inputs and emits a `variant` optional, shaped by
// `shape_inference::ScalarShape`. Presumably the optional holds no value, as
// the name says — confirm in the kernel.
REGISTER_OP("OptionalNone")
    .Output("optional: variant")
    .SetShapeFn(shape_inference::ScalarShape);
1077
// `OptionalHasValue`: takes a `variant` optional and emits a `bool`
// `has_value`, shaped by `shape_inference::ScalarShape`.
REGISTER_OP("OptionalHasValue")
    .Input("optional: variant")
    .Output("has_value: bool")
    .SetShapeFn(shape_inference::ScalarShape);
1082
// `OptionalGetValue`: takes a `variant` optional and emits one output per
// entry of `output_types`, shaped by `shape_inference::DatasetIteratorShape`.
REGISTER_OP("OptionalGetValue")
    .Input("optional: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
1089
// `IteratorGetNextAsOptional`: reads from an iterator `resource` and wraps the
// result in a `variant` optional, shaped by `shape_inference::ScalarShape`.
// The output's full type is a TFT_OPTIONAL container over `output_types`.
REGISTER_OP("IteratorGetNextAsOptional")
    .Input("iterator: resource")
    .Output("optional: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_OPTIONAL,
                                                           "output_types"))
    // Presumably maps input 0's TFT_ITERATOR type to a TFT_OPTIONAL output
    // type — confirm against full_type::MapCovariant.
    .SetForwardTypeFn(full_type::MapCovariant(TFT_ITERATOR, TFT_OPTIONAL, 0))
    .SetShapeFn(shape_inference::ScalarShape);
1099
// `ModelDataset`: wraps `input_dataset` and emits a `variant` dataset handle,
// shaped by `shape_inference::ScalarShape`. `algorithm`, `cpu_budget` and
// `ram_budget` are int attrs that all default to 0. NOTE(review): what 0
// means for each is decided by the kernel — confirm there.
REGISTER_OP("ModelDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("algorithm: int = 0")
    .Attr("cpu_budget: int = 0")
    .Attr("ram_budget: int = 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1111
1112// TODO(b/124308749): Add a stateful version of MapDefun and use it when `f`
1113// is stateful.
1114REGISTER_OP("MapDefun")
1115 .Input("arguments: Targuments")
1116 .Input("captured_inputs: Tcaptured")
1117 .Output("output: output_types")
1118 .Attr("Targuments: list(type) >= 1")
1119 .Attr("Tcaptured: list(type) >= 0 = []")
1120 .Attr("output_types: list(type) >= 1")
1121 .Attr("output_shapes: list(shape) >= 1")
1122 .Attr("f: func")
1123 .Attr("max_intra_op_parallelism: int = 1")
1124 .SetShapeFn([](shape_inference::InferenceContext* c) {
1125 std::vector<PartialTensorShape> output_shapes;
1126 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
1127 DataTypeVector t_args;
1128 TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
1129 if (output_shapes.size() != c->num_outputs()) {
1130 return errors::InvalidArgument(
1131 "`output_shapes` must be the same length as `output_types` (",
1132 output_shapes.size(), " vs. ", c->num_outputs(), ")");
1133 }
1134
1135 int64_t dim_zero = -1;
1136 for (size_t i = 0; i < t_args.size(); ++i) {
1137 if (c->Rank(c->input(i)) == 0) {
1138 return errors::InvalidArgument(
1139 "Arguments must have rank at least 1. Input ", i,
1140 " has rank of 0.");
1141 }
1142 auto dim_handle = c->Dim(c->input(i), 0);
1143 if (c->ValueKnown(dim_handle)) {
1144 if (dim_zero == -1) {
1145 dim_zero = c->Value(dim_handle);
1146 } else if (c->Value(dim_handle) != dim_zero) {
1147 return errors::InvalidArgument(
1148 "Arguments must have the same dimension 0.");
1149 }
1150 }
1151 }
1152
1153 for (size_t i = 0; i < output_shapes.size(); ++i) {
1154 PartialTensorShape s({});
1155 s = s.Concatenate(dim_zero);
1156 s = s.Concatenate(output_shapes[i]);
1157 shape_inference::ShapeHandle output_shape_handle;
1158
1159 TF_RETURN_IF_ERROR(
1160 c->MakeShapeFromPartialTensorShape(s, &output_shape_handle));
1161 c->set_output(static_cast<int>(i), output_shape_handle);
1162 }
1163 return OkStatus();
1164 });
1165
// `WrapDatasetVariant`: maps a `variant` input to a `variant` output, shaped
// by `shape_inference::ScalarShape`. The inverse is presumably
// `UnwrapDatasetVariant` — confirm in the kernels.
REGISTER_OP("WrapDatasetVariant")
    .Input("input_handle: variant")
    .Output("output_handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);
1170
// `UnwrapDatasetVariant`: maps a `variant` input to a `variant` output, shaped
// by `shape_inference::ScalarShape`. Declared with the same signature as
// `WrapDatasetVariant`.
REGISTER_OP("UnwrapDatasetVariant")
    .Input("input_handle: variant")
    .Output("output_handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);
1175
1176REGISTER_OP("AnonymousMultiDeviceIterator")
1177 .Output("handle: resource")
1178 .Output("deleter: variant")
1179 .Attr("devices: list(string) >= 1")
1180 .Attr("output_types: list(type) >= 1")
1181 .Attr("output_shapes: list(shape) >= 1")
1182 .SetShapeFn([](shape_inference::InferenceContext* c) {
1183 c->set_output(0, c->Scalar());
1184 c->set_output(1, c->Scalar());
1185 return OkStatus();
1186 });
1187
1188REGISTER_OP("AnonymousMultiDeviceIteratorV3")
1189 .Output("handle: resource")
1190 .Attr("devices: list(string) >= 1")
1191 .Attr("output_types: list(type) >= 1")
1192 .Attr("output_shapes: list(shape) >= 1")
1193 .SetShapeFn([](shape_inference::InferenceContext* c) {
1194 c->set_output(0, c->Scalar());
1195 return OkStatus();
1196 });
1197
// `MultiDeviceIterator`: takes no inputs and emits a `resource` handle,
// shaped by `shape_inference::ScalarShape`. `devices` (at least one),
// `shared_name` and `container` are required attrs.
REGISTER_OP("MultiDeviceIterator")
    .Output("handle: resource")
    .Attr("devices: list(string) >= 1")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
1206
// `MultiDeviceIteratorInit`: takes a dataset `variant`, a multi-device
// iterator `resource` and an int64 `max_buffer_size`, and emits an int64
// `incarnation_id`, shaped by `shape_inference::ScalarShape`.
// NOTE(review): the shape function does not validate the rank of
// `max_buffer_size`.
REGISTER_OP("MultiDeviceIteratorInit")
    .Input("dataset: variant")
    .Input("multi_device_iterator: resource")
    .Input("max_buffer_size: int64")
    .Output("incarnation_id: int64")
    .SetShapeFn(shape_inference::ScalarShape);
1213
// `MultiDeviceIteratorGetNextFromShard`: reads from a multi-device iterator
// `resource`, selected by an int32 `shard_num` and an int64
// `incarnation_id`. Emits one output per entry of `output_types`, shaped by
// `shape_inference::DatasetIteratorShape`.
REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
    .Input("multi_device_iterator: resource")
    .Input("shard_num: int32")
    .Input("incarnation_id: int64")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
1222
// `MultiDeviceIteratorToStringHandle`: converts a multi-device iterator
// `resource` into a `string` handle, shaped by `shape_inference::ScalarShape`.
REGISTER_OP("MultiDeviceIteratorToStringHandle")
    .Input("multi_device_iterator: resource")
    .Output("string_handle: string")
    // Presumably records input 0's full type encoded as TFT_STRING — confirm
    // against full_type::Encode.
    .SetForwardTypeFn(full_type::Encode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);
1228
// `MultiDeviceIteratorFromStringHandle`: converts a `string` handle back into
// a multi-device iterator `resource`, shaped by `shape_inference::ScalarShape`.
// `output_types`/`output_shapes` are optional (default: empty lists).
REGISTER_OP("MultiDeviceIteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("multi_device_iterator: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    // Presumably the counterpart of Encode(TFT_STRING, 0) on the ToStringHandle
    // op — confirm against full_type::Decode.
    .SetForwardTypeFn(full_type::Decode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);
1236
// `OptionsDataset`: wraps `input_dataset` with the options given in the
// required `serialized_options` string attr. Emits a `variant` dataset
// handle, shaped by `shape_inference::ScalarShape`.
REGISTER_OP("OptionsDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("serialized_options: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1247
// `GetOptions`: takes a dataset `variant` and emits a `string`
// `serialized_options`, shaped by `shape_inference::ScalarShape`.
REGISTER_OP("GetOptions")
    .Input("input_dataset: variant")
    .Output("serialized_options: string")
    .SetShapeFn(shape_inference::ScalarShape);
1252
// `FinalizeDataset`: wraps `input_dataset` and emits a `variant` dataset
// handle, shaped by `shape_inference::ScalarShape`. `has_captured_ref`
// defaults to false.
REGISTER_OP("FinalizeDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("has_captured_ref: bool = false")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1262
1263} // namespace tensorflow
1264