1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/framework/common_shape_fns.h" |
16 | #include "tensorflow/core/framework/full_type.pb.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/op_def_builder.h" |
19 | #include "tensorflow/core/framework/shape_inference.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | // -------------------------------------------------------------------------- |
24 | |
25 | // The ops in this section can be composed to define an input |
26 | // pipeline. Each op produces a DT_VARIANT tensor that represents |
// a DAG of "dataset" objects. A "dataset" object can be converted
28 | // to a stateful "iterator" by passing the "dataset" to the |
29 | // "MakeIterator" op. |
30 | // |
31 | // TODO(b/123753214): DT_VARIANT tensors that represent "dataset" objects are |
32 | // not presently serializable. To avoid issues with graph optimizations, such |
33 | // as constant folding, CSE, or DCE, ensure that any "source dataset" ops |
34 | // (i.e. ops that output a dataset and do not take one as input) are |
35 | // marked as "do not optimize". |
36 | |
// TODO(mrry): Validate that `components` have shapes compatible with
// `output_shapes`.
// Source dataset op: wraps the `components` tensors as a dataset, emitting a
// scalar DT_VARIANT handle. Marked do-not-optimize per the note above.
REGISTER_OP("TensorDataset")
    .Input("components: Toutput_types")
    .Output("handle: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    // Full type: TFT_DATASET container parameterized by the variadic
    // `Toutput_types` attr.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "Toutput_types"))
    .SetShapeFn(shape_inference::ScalarShape);
49 | |
// TODO(mrry): Validate that the dim-0 slices of `components` have shapes
// compatible with `output_shapes`.
// Source dataset op over the dim-0 slices of `components`; outputs a scalar
// DT_VARIANT handle.
REGISTER_OP("TensorSliceDataset")
    .Input("components: Toutput_types")
    .Output("handle: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("is_files: bool = false")
    .Attr("metadata: string = ''")
    .Attr("replicate_on_split: bool = false")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "Toutput_types"))
    // Forward type: each element is the "unstacked" (dim-0 slice) form of the
    // input tensors.
    .SetForwardTypeFn(full_type::MultiaryUnstack(TFT_DATASET,
                                                 full_type::UnstackTensor))
    .SetShapeFn(shape_inference::ScalarShape);
66 | |
// Source dataset op built from the COO sparse-tensor triple
// (`indices`, `values`, `dense_shape`); outputs a scalar DT_VARIANT handle.
REGISTER_OP("SparseTensorSliceDataset")
    .Input("indices: int64")
    .Input("values: Tvalues")
    .Input("dense_shape: int64")
    .Output("handle: variant")
    .Attr("Tvalues: type")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, "Tvalues"))
    .SetShapeFn(shape_inference::ScalarShape);

// Source dataset op parameterized by three functions (`init_func`,
// `next_func`, `finalize_func`), each with its own list of captured
// arguments; outputs a scalar DT_VARIANT handle.
REGISTER_OP("GeneratorDataset")
    .Input("init_func_other_args: Tinit_func_args")
    .Input("next_func_other_args: Tnext_func_args")
    .Input("finalize_func_other_args: Tfinalize_func_args")
    .Output("handle: variant")
    .Attr("init_func: func")
    .Attr("next_func: func")
    .Attr("finalize_func: func")
    .Attr("Tinit_func_args: list(type) >= 0")
    .Attr("Tnext_func_args: list(type) >= 0")
    .Attr("Tfinalize_func_args: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
95 | |
// Combines `N` input dataset handles into one; outputs a scalar DT_VARIANT
// handle.
REGISTER_OP("ZipDataset")
    .Input("input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Combines exactly two input dataset handles (`input_dataset`,
// `another_dataset`) into one; outputs a scalar DT_VARIANT handle.
REGISTER_OP("ConcatenateDataset")
    .Input("input_dataset: variant")
    .Input("another_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
117 | |
118 | REGISTER_OP("RepeatDataset" ) |
119 | .Input("input_dataset: variant" ) |
120 | .Input("count: int64" ) |
121 | .Output("handle: variant" ) |
122 | .Attr("output_types: list(type) >= 1" ) |
123 | .Attr("output_shapes: list(shape) >= 1" ) |
124 | .Attr("metadata: string = ''" ) |
125 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
126 | "output_types" )) |
127 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
128 | shape_inference::ShapeHandle count_shape; |
129 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); |
130 | return shape_inference::ScalarShape(c); |
131 | }); |
132 | |
133 | REGISTER_OP("TakeDataset" ) |
134 | .Input("input_dataset: variant" ) |
135 | .Input("count: int64" ) |
136 | .Output("handle: variant" ) |
137 | .Attr("output_types: list(type) >= 1" ) |
138 | .Attr("output_shapes: list(shape) >= 1" ) |
139 | .Attr("metadata: string = ''" ) |
140 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
141 | "output_types" )) |
142 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
143 | shape_inference::ShapeHandle count_shape; |
144 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); |
145 | return shape_inference::ScalarShape(c); |
146 | }); |
147 | |
148 | REGISTER_OP("SkipDataset" ) |
149 | .Input("input_dataset: variant" ) |
150 | .Input("count: int64" ) |
151 | .Output("handle: variant" ) |
152 | .Attr("output_types: list(type) >= 1" ) |
153 | .Attr("output_shapes: list(shape) >= 1" ) |
154 | .Attr("metadata: string = ''" ) |
155 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
156 | "output_types" )) |
157 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
158 | shape_inference::ShapeHandle count_shape; |
159 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); |
160 | return shape_inference::ScalarShape(c); |
161 | }); |
162 | |
// Maps function `f` (with captured `other_arguments`) over `input_dataset`;
// outputs a scalar DT_VARIANT handle.
REGISTER_OP("MapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("preserve_cardinality: bool = false")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Parallel variant of MapDataset. NOTE: `num_parallel_calls` is int32 here;
// the V2 op below takes int64.
REGISTER_OP("ParallelMapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("num_parallel_calls: int32")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("sloppy: bool = false")
    .Attr("preserve_cardinality: bool = false")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// V2: int64 `num_parallel_calls` and a string `deterministic` attr instead of
// the boolean `sloppy` attr.
REGISTER_OP("ParallelMapDatasetV2")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("preserve_cardinality: bool = false")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
212 | |
// Wraps `input_dataset` with prefetching controlled by the scalar
// `buffer_size` input; outputs a scalar DT_VARIANT handle.
REGISTER_OP("PrefetchDataset")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("slack_period: int = 0")
    .Attr("legacy_autotune: bool = true")
    .Attr("buffer_size_min: int = 0")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // buffer_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
231 | |
// Maps dataset-producing function `f` over `input_dataset` and flattens the
// results; outputs a scalar DT_VARIANT handle.
REGISTER_OP("FlatMapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Like FlatMapDataset but interleaves results, controlled by the
// `cycle_length` and `block_length` inputs.
REGISTER_OP("InterleaveDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Parallel interleave with a `num_parallel_calls` input and boolean `sloppy`
// attr.
REGISTER_OP("ParallelInterleaveDatasetV2")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("sloppy: bool = false")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// V3: replaces the boolean `sloppy` attr with a string `deterministic` attr.
REGISTER_OP("ParallelInterleaveDatasetV3")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Like V3, but adds buffer_output_elements and prefetch_input_elements.
REGISTER_OP("ParallelInterleaveDatasetV4")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
315 | |
// Filters `input_dataset` by `predicate` (with captured `other_arguments`);
// outputs a scalar DT_VARIANT handle.
REGISTER_OP("FilterDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("predicate: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Parallel variant of FilterDataset with a `num_parallel_calls` input and a
// string `deterministic` attr.
REGISTER_OP("ParallelFilterDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("predicate: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// This op is no longer supported.
REGISTER_OP("FilterByLastComponentDataset")
    .Input("input_dataset: variant")
    .Output("output: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
354 | |
// Windows `input_dataset` according to the scalar `size`, `shift`, `stride`,
// and `drop_remainder` inputs; outputs a scalar DT_VARIANT handle.
REGISTER_OP("WindowDataset")
    .Input("input_dataset: variant")
    .Input("size: int64")
    .Input("shift: int64")
    .Input("stride: int64")
    .Input("drop_remainder: bool")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // size, shift, stride, and drop_remainder should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Wraps the `inputs` tensors in a dataset variant; outputs a scalar
// DT_VARIANT handle.
REGISTER_OP("WindowOp")
    .Input("inputs: Tinputs")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("Tinputs: list(type) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
386 | |
// Batches `input_dataset` according to the scalar `batch_size` input; outputs
// a scalar DT_VARIANT handle.
REGISTER_OP("BatchDataset")
    .Input("input_dataset: variant")
    .Input("batch_size: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // batch_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// V2: adds the `drop_remainder` input and `parallel_copy` attr, plus a
// forward type function that batches the element type.
REGISTER_OP("BatchDatasetV2")
    .Input("input_dataset: variant")
    .Input("batch_size: int64")
    .Input("drop_remainder: bool")
    .Output("handle: variant")
    .Attr("parallel_copy: bool = false")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    // Element type of the output dataset is the batched form of input 0's
    // element type.
    .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
                                              full_type::BatchTensor))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // batch_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // drop_remainder should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Parallel batching with `num_parallel_calls` input and `deterministic` attr.
REGISTER_OP("ParallelBatchDataset")
    .Input("input_dataset: variant")
    .Input("batch_size: int64")
    .Input("num_parallel_calls: int64")
    .Input("drop_remainder: bool")
    .Output("handle: variant")
    .Attr("parallel_copy: bool = false")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // batch_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // num_parallel_calls should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      // drop_remainder should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
449 | |
450 | REGISTER_OP("ShardDataset" ) |
451 | .Input("input_dataset: variant" ) |
452 | .Input("num_shards: int64" ) |
453 | .Input("index: int64" ) |
454 | .Output("handle: variant" ) |
455 | .Attr("require_non_empty: bool = false" ) |
456 | .Attr("output_types: list(type) >= 1" ) |
457 | .Attr("output_shapes: list(shape) >= 1" ) |
458 | .Attr("metadata: string = ''" ) |
459 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
460 | "output_types" )) |
461 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
462 | shape_inference::ShapeHandle unused; |
463 | // num_shards should be a scalar. |
464 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
465 | // index should be a scalar. |
466 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
467 | return shape_inference::ScalarShape(c); |
468 | }); |
469 | |
// TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of
// `output_types` and `output_shapes` are `N` the `output_shapes` are (as far as
// possible to tell statically) compatible with `padded_shapes`, and that
// `padding_values` are all scalars.
// Batches `input_dataset` with padding; takes `N` padded-shape tensors plus
// one padding value per output component.
REGISTER_OP("PaddedBatchDataset")
    .Input("input_dataset: variant")
    .Input("batch_size: int64")
    .Input("padded_shapes: N * int64")
    .Input("padding_values: Toutput_types")
    .Output("handle: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "Toutput_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // batch_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// V2: adds the `drop_remainder` input and `parallel_copy` attr.
REGISTER_OP("PaddedBatchDatasetV2")
    .Input("input_dataset: variant")
    .Input("batch_size: int64")
    .Input("padded_shapes: N * int64")
    .Input("padding_values: Toutput_types")
    .Input("drop_remainder: bool")
    .Output("handle: variant")
    .Attr("parallel_copy: bool = false")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "Toutput_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // batch_size should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // drop_remainder should be a scalar. Because the preceding inputs are
      // variadic (N padded_shapes, |Toutput_types| padding_values), it is
      // addressed as the last input rather than by a fixed index.
      TF_RETURN_IF_ERROR(
          c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
516 | |
// Source dataset over the range defined by the scalar `start`, `stop`, and
// `step` inputs; outputs a scalar DT_VARIANT handle.
REGISTER_OP("RangeDataset")
    .Input("start: int64")
    .Input("stop: int64")
    .Input("step: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .Attr("replicate_on_split: bool = false")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // start, stop, and step should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
537 | |
// Applies the rewrite named by the `rewrite_name` input to `input_dataset`;
// outputs a scalar DT_VARIANT handle.
REGISTER_OP("RewriteDataset")
    .Input("input_dataset: variant")
    .Input("rewrite_name: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
547 | |
// Creates a seed-generator resource from `seed`/`seed2`/`reshuffle`; outputs
// a scalar resource handle plus a scalar variant `deleter` token.
REGISTER_OP("AnonymousSeedGenerator")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Input("reshuffle: bool")
    .Output("handle: resource")
    .Output("deleter: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Both outputs are scalars.
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return OkStatus();
    });

// Returns the cardinality of `input_dataset` as a scalar int64.
REGISTER_OP("DatasetCardinality")
    .Input("input_dataset: variant")
    .Output("cardinality: int64")
    .SetShapeFn(shape_inference::ScalarShape);

// Deletes the seed-generator resource identified by `handle`/`deleter`.
// No outputs.
REGISTER_OP("DeleteSeedGenerator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);

// Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator.
REGISTER_OP("AnonymousRandomSeedGenerator")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Output("handle: resource")
    .Output("deleter: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return OkStatus();
    });

// Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator.
REGISTER_OP("DeleteRandomSeedGenerator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);

// Produces a placeholder seed-generator resource handle (no inputs).
REGISTER_OP("DummySeedGenerator")
    .Output("handle: resource")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return OkStatus();
    });
594 | |
// Shuffles `input_dataset` using the scalar `buffer_size`, `seed`, and
// `seed2` inputs; outputs a scalar DT_VARIANT handle.
REGISTER_OP("ShuffleDataset")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Output("handle: variant")
    .Attr("reshuffle_each_iteration: bool = true")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // buffer_size, seed, and seed2 should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// V2: replaces the seed inputs with a `seed_generator` resource input.
REGISTER_OP("ShuffleDatasetV2")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
    .Input("seed_generator: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // buffer_size and seed_generator should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// V3: takes both the seed inputs and a `seed_generator` resource input.
REGISTER_OP("ShuffleDatasetV3")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Input("seed_generator: resource")
    .Output("handle: variant")
    .Attr("reshuffle_each_iteration: bool = true")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // buffer_size, seed, seed2, and seed_generator should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// Fused shuffle + repeat with an additional scalar `count` input.
REGISTER_OP("ShuffleAndRepeatDataset")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Input("count: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("reshuffle_each_iteration: bool = true")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // buffer_size, seed, seed2, and count should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// V2: adds a `seed_generator` resource input to ShuffleAndRepeatDataset.
REGISTER_OP("ShuffleAndRepeatDatasetV2")
    .Input("input_dataset: variant")
    .Input("buffer_size: int64")
    .Input("seed: int64")
    .Input("seed2: int64")
    .Input("count: int64")
    .Input("seed_generator: resource")
    .Output("handle: variant")
    .Attr("reshuffle_each_iteration: bool = true")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // buffer_size, seed, seed2, count, and seed_generator should be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
704 | |
// Creates a memory-cache resource; outputs a scalar resource handle plus a
// scalar variant `deleter` token.
REGISTER_OP("AnonymousMemoryCache")
    .Output("handle: resource")
    .Output("deleter: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return OkStatus();
    });

// Deletes the memory-cache resource identified by `handle`/`deleter`.
// No outputs.
REGISTER_OP("DeleteMemoryCache")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);

// Produces a placeholder memory-cache resource handle (no inputs).
REGISTER_OP("DummyMemoryCache")
    .Output("handle: resource")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return OkStatus();
    });

// Caches `input_dataset`, keyed by the scalar `filename` input; outputs a
// scalar DT_VARIANT handle.
REGISTER_OP("CacheDataset")
    .Input("input_dataset: variant")
    .Input("filename: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    // TODO(mdan): Should these use type inference instead?
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // filename should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// V2: adds a `cache` resource input alongside `filename`.
REGISTER_OP("CacheDatasetV2")
    .Input("input_dataset: variant")
    .Input("filename: string")
    .Input("cache: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // filename should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // cache should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
761 | |
// Creates a dataset that emits the lines of one or more text files.
// A "source dataset" op (takes no dataset input), hence SetDoNotOptimize;
// see the TODO(b/123753214) comment at the top of this file.
REGISTER_OP("TextLineDataset")
    .Input("filenames: string")
    .Input("compression_type: string")
    .Input("buffer_size: int64")
    .Attr("metadata: string = ''")
    .Output("handle: variant")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
                                                        TFT_STRING))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `filenames` must be a scalar or a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
      // `compression_type` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // `buffer_size` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
781 | |
// Creates a dataset that emits fixed-length records from one or more files.
// A "source dataset" op, hence SetDoNotOptimize; see the TODO(b/123753214)
// comment at the top of this file.
REGISTER_OP("FixedLengthRecordDataset")
    .Input("filenames: string")
    .Input("header_bytes: int64")
    .Input("record_bytes: int64")
    .Input("footer_bytes: int64")
    .Input("buffer_size: int64")
    .Attr("metadata: string = ''")
    .Output("handle: variant")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
                                                        TFT_STRING))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `filenames` must be a scalar or a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
      // header_bytes, record_bytes, footer_bytes, buffer_size should be
      // scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
805 | |
806 | REGISTER_OP("FixedLengthRecordDatasetV2" ) |
807 | .Input("filenames: string" ) |
808 | .Input("header_bytes: int64" ) |
809 | .Input("record_bytes: int64" ) |
810 | .Input("footer_bytes: int64" ) |
811 | .Input("buffer_size: int64" ) |
812 | .Input("compression_type: string" ) |
813 | .Attr("metadata: string = ''" ) |
814 | .Output("handle: variant" ) |
815 | .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. |
816 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, |
817 | TFT_STRING)) |
818 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
819 | shape_inference::ShapeHandle unused; |
820 | // `filenames` must be a scalar or a vector. |
821 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); |
822 | // header_bytes, record_bytes, footer_bytes, buffer_size should be |
823 | // scalars. |
824 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
825 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
826 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
827 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
828 | return shape_inference::ScalarShape(c); |
829 | }); |
830 | |
// Creates a dataset that emits the records of one or more TFRecord files.
// A "source dataset" op, hence SetDoNotOptimize; see the TODO(b/123753214)
// comment at the top of this file.
REGISTER_OP("TFRecordDataset")
    .Input("filenames: string")
    .Input("compression_type: string")
    .Input("buffer_size: int64")
    .Attr("metadata: string = ''")
    .Output("handle: variant")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
                                                        TFT_STRING))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `filenames` must be a scalar or a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
      // `compression_type` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // `buffer_size` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
850 | |
// Creates a named iterator resource, shareable across sessions via
// `shared_name`/`container`. The handle output is a scalar.
REGISTER_OP("Iterator")
    .Output("handle: resource")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

// V2 of Iterator. Same signature as V1; the difference is in the kernel,
// which is not visible here.
REGISTER_OP("IteratorV2")
    .Output("handle: resource")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

// Creates an anonymous (unnamed, per-instance) iterator resource; unlike
// Iterator there is no shared_name/container, so it cannot be looked up by
// name.
REGISTER_OP("AnonymousIterator")
    .Output("handle: resource")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
872 | |
873 | REGISTER_OP("AnonymousIteratorV2" ) |
874 | .Output("handle: resource" ) |
875 | .Output("deleter: variant" ) |
876 | .Attr("output_types: list(type) >= 1" ) |
877 | .Attr("output_shapes: list(shape) >= 1" ) |
878 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
879 | c->set_output(0, c->Scalar()); |
880 | c->set_output(1, c->Scalar()); |
881 | return OkStatus(); |
882 | }); |
883 | |
884 | REGISTER_OP("AnonymousIteratorV3" ) |
885 | .Output("handle: resource" ) |
886 | .Attr("output_types: list(type) >= 1" ) |
887 | .Attr("output_shapes: list(shape) >= 1" ) |
888 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
889 | c->set_output(0, c->Scalar()); |
890 | return OkStatus(); |
891 | }); |
892 | |
// Deletes the iterator resource identified by `handle`, using the `deleter`
// produced alongside it (see AnonymousIteratorV2). No outputs.
REGISTER_OP("DeleteIterator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);

// Deletes a multi-device iterator along with the `N` per-device iterator
// resources that feed from it. No outputs.
REGISTER_OP("DeleteMultiDeviceIterator")
    .Input("multi_device_iterator: resource")
    .Input("iterators: N * resource")
    .Input("deleter: variant")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::NoOutputs);

// Initializes the `iterator` resource from the given `dataset` variant.
// The reverse type fn maps the iterator's full type back through
// TFT_DATASET -> TFT_ITERATOR onto input 1. No outputs.
REGISTER_OP("MakeIterator")
    .Input("dataset: variant")
    .Input("iterator: resource")
    .SetTypeConstructor(full_type::NoOutputs())
    .SetReverseTypeFn(1, full_type::MapCovariant(TFT_DATASET, TFT_ITERATOR, 0))
    .SetShapeFn(shape_inference::NoOutputs);
911 | |
// Creates an iterator whose dataset is produced by calling the
// `dataset_factory` function once. Stateful: the resource persists across
// calls. The handle output is a scalar.
REGISTER_OP("OneShotIterator")
    .Output("handle: resource")
    .Attr("dataset_factory: func")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Gets the next element from the iterator, as one tensor per component.
// Output shapes come from the `output_shapes` attr via
// DatasetIteratorShape.
REGISTER_OP("IteratorGetNext")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// Same signature as IteratorGetNext; the synchronous behavior difference
// lives in the kernel, which is not visible here.
REGISTER_OP("IteratorGetNextSync")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
935 | |
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
//
// Extracts the single element of `dataset` as one tensor per component.
REGISTER_OP("DatasetToSingleElement")
    .Input("dataset: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
//
// Folds `f` over the elements of `input_dataset`, starting from
// `initial_state`; `other_arguments` are captured inputs passed to every
// invocation of `f`. The final state is returned as `components`.
REGISTER_OP("ReduceDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("components: output_types")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("metadata: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::DatasetIteratorShape);
967 | |
// Converts an iterator resource handle to a scalar string handle.
REGISTER_OP("IteratorToStringHandle")
    .Input("resource_handle: resource")
    .Output("string_handle: string")
    .SetShapeFn(shape_inference::ScalarShape);

// Converts a string handle (from IteratorToStringHandle) back to an iterator
// resource handle. output_types/output_shapes may be empty, in which case
// they are presumably unchecked — see the kernel.
REGISTER_OP("IteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);

// Same signature as IteratorFromStringHandle; the behavioral difference
// lives in the kernel, which is not visible here.
REGISTER_OP("IteratorFromStringHandleV2")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);
986 | |
// Serializes the state of an iterator resource into a variant vector of
// unknown length. `external_state_policy` is an enum-valued int attr
// controlling how external state is handled; see the kernel for the values.
REGISTER_OP("SerializeIterator")
    .Input("resource_handle: resource")
    .Attr("external_state_policy: int = 0")
    .Output("serialized: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Output is rank-1 with an unknown number of state tensors.
      c->set_output(0, c->Vector(c->UnknownDim()));
      return OkStatus();
    });

// Restores iterator state previously produced by SerializeIterator into the
// given iterator resource. No outputs.
REGISTER_OP("DeserializeIterator")
    .Input("resource_handle: resource")
    .Input("serialized: variant")
    .SetShapeFn(shape_inference::NoOutputs);
1000 | |
// Serializes `input_dataset` into a scalar string GraphDef.
// NOTE(review): the attr name "stateful_whitelist" is legacy terminology;
// it cannot be renamed here without breaking serialized graphs.
REGISTER_OP("DatasetToGraph")
    .Input("input_dataset: variant")
    .Attr("stateful_whitelist: list(string) >= 0 = []")
    .Attr("allow_stateful: bool = false")
    .Attr("strip_device_assignment: bool = false")
    .Output("graph: string")
    .SetShapeFn(shape_inference::ScalarShape);

// V2 of DatasetToGraph: replaces the whitelist/allow_stateful attrs with an
// `external_state_policy` enum-valued int. The forward type fn encodes the
// dataset's full type into the TFT_STRING output.
REGISTER_OP("DatasetToGraphV2")
    .Input("input_dataset: variant")
    .Attr("external_state_policy: int = 0")
    .Attr("strip_device_assignment: bool = false")
    .Output("graph: string")
    .SetForwardTypeFn(full_type::Encode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);
1016 | |
// Applies the named graph-rewrite `optimizations` to `input_dataset` and
// returns the optimized dataset handle.
REGISTER_OP("OptimizeDataset")
    .Input("input_dataset: variant")
    .Input("optimizations: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("optimization_configs: list(string) = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// V2 of OptimizeDataset: splits the single `optimizations` input into
// explicitly enabled, disabled, and default optimization lists.
REGISTER_OP("OptimizeDatasetV2")
    .Input("input_dataset: variant")
    .Input("optimizations_enabled: string")
    .Input("optimizations_disabled: string")
    .Input("optimizations_default: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("optimization_configs: list(string) = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1040 | |
// Wraps the input `components` into a scalar variant "optional" value.
// The shape fn also records per-component shape/dtype information on the
// output handle so downstream ops (e.g. OptionalGetValue) can recover it.
REGISTER_OP("OptionalFromValue")
    .Input("components: Toutput_types")
    .Output("optional: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_OPTIONAL,
                                                           "Toutput_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      std::vector<DataType> dtypes;
      TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes));
      // The optional itself is a scalar variant.
      c->set_output(0, c->Scalar());
      std::vector<shape_inference::ShapeAndType> shapes_and_types;
      shapes_and_types.reserve(c->num_inputs());
      const FullTypeDef& ret_types = c->ret_types();
      for (int i = 0; i < c->num_inputs(); ++i) {
        // TODO(mdan): output_type(i) == optional is incorrect.
        // "Optional" is the type of the whole container, not of individual
        // elements.
        //
        // Why ret_types.args(0) and not args(i) --
        // For example if Toutput_types is (int32, float32), then
        // ret_types.args[0] (i.e. the 0th output) is
        // Optional[Record[Tensor[int32, s1], Tensor[float32, s2]]]
        // set_output_handle_shapes_and_types tracks the same thing, but in
        // a transposed way:
        // {ShapeAndType(in32, s1, Optional), ShapeAndType(in32, s2, Optional)}
        // That should be corrected in the future (see todo above).
        shapes_and_types.emplace_back(c->input(i), dtypes[i],
                                      ret_types.args(0));
      }
      c->set_output_handle_shapes_and_types(0, shapes_and_types);
      return OkStatus();
    });
1073 | |
// Produces an empty optional (no value), as a scalar variant.
REGISTER_OP("OptionalNone")
    .Output("optional: variant")
    .SetShapeFn(shape_inference::ScalarShape);

// Returns a scalar bool indicating whether `optional` holds a value.
REGISTER_OP("OptionalHasValue")
    .Input("optional: variant")
    .Output("has_value: bool")
    .SetShapeFn(shape_inference::ScalarShape);

// Unpacks the value held by `optional` into one tensor per component, with
// shapes taken from the `output_shapes` attr via DatasetIteratorShape.
REGISTER_OP("OptionalGetValue")
    .Input("optional: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// Gets the next element from `iterator` wrapped in a scalar optional
// variant (empty if the iterator is exhausted). The forward type fn maps
// the iterator's full type TFT_ITERATOR -> TFT_OPTIONAL onto the output.
REGISTER_OP("IteratorGetNextAsOptional")
    .Input("iterator: resource")
    .Output("optional: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_OPTIONAL,
                                                           "output_types"))
    .SetForwardTypeFn(full_type::MapCovariant(TFT_ITERATOR, TFT_OPTIONAL, 0))
    .SetShapeFn(shape_inference::ScalarShape);
1099 | |
// Wraps `input_dataset` for performance modeling/autotuning. `algorithm`,
// `cpu_budget`, and `ram_budget` are int attrs interpreted by the kernel
// (0 defaults — semantics not visible here).
REGISTER_OP("ModelDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("algorithm: int = 0")
    .Attr("cpu_budget: int = 0")
    .Attr("ram_budget: int = 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1111 | |
1112 | // TODO(b/124308749): Add a stateful version of MapDefun and use it when `f` |
1113 | // is stateful. |
1114 | REGISTER_OP("MapDefun" ) |
1115 | .Input("arguments: Targuments" ) |
1116 | .Input("captured_inputs: Tcaptured" ) |
1117 | .Output("output: output_types" ) |
1118 | .Attr("Targuments: list(type) >= 1" ) |
1119 | .Attr("Tcaptured: list(type) >= 0 = []" ) |
1120 | .Attr("output_types: list(type) >= 1" ) |
1121 | .Attr("output_shapes: list(shape) >= 1" ) |
1122 | .Attr("f: func" ) |
1123 | .Attr("max_intra_op_parallelism: int = 1" ) |
1124 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1125 | std::vector<PartialTensorShape> output_shapes; |
1126 | TF_RETURN_IF_ERROR(c->GetAttr("output_shapes" , &output_shapes)); |
1127 | DataTypeVector t_args; |
1128 | TF_RETURN_IF_ERROR(c->GetAttr("Targuments" , &t_args)); |
1129 | if (output_shapes.size() != c->num_outputs()) { |
1130 | return errors::InvalidArgument( |
1131 | "`output_shapes` must be the same length as `output_types` (" , |
1132 | output_shapes.size(), " vs. " , c->num_outputs(), ")" ); |
1133 | } |
1134 | |
1135 | int64_t dim_zero = -1; |
1136 | for (size_t i = 0; i < t_args.size(); ++i) { |
1137 | if (c->Rank(c->input(i)) == 0) { |
1138 | return errors::InvalidArgument( |
1139 | "Arguments must have rank at least 1. Input " , i, |
1140 | " has rank of 0." ); |
1141 | } |
1142 | auto dim_handle = c->Dim(c->input(i), 0); |
1143 | if (c->ValueKnown(dim_handle)) { |
1144 | if (dim_zero == -1) { |
1145 | dim_zero = c->Value(dim_handle); |
1146 | } else if (c->Value(dim_handle) != dim_zero) { |
1147 | return errors::InvalidArgument( |
1148 | "Arguments must have the same dimension 0." ); |
1149 | } |
1150 | } |
1151 | } |
1152 | |
1153 | for (size_t i = 0; i < output_shapes.size(); ++i) { |
1154 | PartialTensorShape s({}); |
1155 | s = s.Concatenate(dim_zero); |
1156 | s = s.Concatenate(output_shapes[i]); |
1157 | shape_inference::ShapeHandle output_shape_handle; |
1158 | |
1159 | TF_RETURN_IF_ERROR( |
1160 | c->MakeShapeFromPartialTensorShape(s, &output_shape_handle)); |
1161 | c->set_output(static_cast<int>(i), output_shape_handle); |
1162 | } |
1163 | return OkStatus(); |
1164 | }); |
1165 | |
// Wraps a dataset variant handle; the inverse of UnwrapDatasetVariant.
// Shape-wise both are scalar-in, scalar-out.
REGISTER_OP("WrapDatasetVariant")
    .Input("input_handle: variant")
    .Output("output_handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);

// Unwraps a handle produced by WrapDatasetVariant.
REGISTER_OP("UnwrapDatasetVariant")
    .Input("input_handle: variant")
    .Output("output_handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);
1175 | |
1176 | REGISTER_OP("AnonymousMultiDeviceIterator" ) |
1177 | .Output("handle: resource" ) |
1178 | .Output("deleter: variant" ) |
1179 | .Attr("devices: list(string) >= 1" ) |
1180 | .Attr("output_types: list(type) >= 1" ) |
1181 | .Attr("output_shapes: list(shape) >= 1" ) |
1182 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1183 | c->set_output(0, c->Scalar()); |
1184 | c->set_output(1, c->Scalar()); |
1185 | return OkStatus(); |
1186 | }); |
1187 | |
1188 | REGISTER_OP("AnonymousMultiDeviceIteratorV3" ) |
1189 | .Output("handle: resource" ) |
1190 | .Attr("devices: list(string) >= 1" ) |
1191 | .Attr("output_types: list(type) >= 1" ) |
1192 | .Attr("output_shapes: list(shape) >= 1" ) |
1193 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1194 | c->set_output(0, c->Scalar()); |
1195 | return OkStatus(); |
1196 | }); |
1197 | |
// Creates a named multi-device iterator resource over the given `devices`,
// shareable via `shared_name`/`container`. The handle output is a scalar.
REGISTER_OP("MultiDeviceIterator")
    .Output("handle: resource")
    .Attr("devices: list(string) >= 1")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

// Initializes a multi-device iterator from `dataset` and returns a scalar
// `incarnation_id` identifying this initialization.
REGISTER_OP("MultiDeviceIteratorInit")
    .Input("dataset: variant")
    .Input("multi_device_iterator: resource")
    .Input("max_buffer_size: int64")
    .Output("incarnation_id: int64")
    .SetShapeFn(shape_inference::ScalarShape);

// Gets the next element for the shard identified by `shard_num`;
// `incarnation_id` must match the value returned by
// MultiDeviceIteratorInit. Output shapes come from the `output_shapes`
// attr via DatasetIteratorShape.
REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
    .Input("multi_device_iterator: resource")
    .Input("shard_num: int32")
    .Input("incarnation_id: int64")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// Converts a multi-device iterator resource handle to a scalar string
// handle; the forward type fn encodes the full type into TFT_STRING.
REGISTER_OP("MultiDeviceIteratorToStringHandle")
    .Input("multi_device_iterator: resource")
    .Output("string_handle: string")
    .SetForwardTypeFn(full_type::Encode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);

// Converts a string handle back to a multi-device iterator resource
// handle; the forward type fn decodes the TFT_STRING-encoded full type.
REGISTER_OP("MultiDeviceIteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("multi_device_iterator: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetForwardTypeFn(full_type::Decode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);
1236 | |
// Attaches serialized tf.data options (the `serialized_options` string
// attr) to `input_dataset`.
REGISTER_OP("OptionsDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("serialized_options: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Returns the serialized options of `input_dataset` as a scalar string.
REGISTER_OP("GetOptions")
    .Input("input_dataset: variant")
    .Output("serialized_options: string")
    .SetShapeFn(shape_inference::ScalarShape);

// Finalizes `input_dataset` for consumption. `has_captured_ref` indicates
// whether the dataset captures reference-typed state (interpreted by the
// kernel, not visible here).
REGISTER_OP("FinalizeDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("has_captured_ref: bool = false")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1262 | |
1263 | } // namespace tensorflow |
1264 | |