1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/framework/common_shape_fns.h" |
16 | #include "tensorflow/core/framework/full_type.pb.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | REGISTER_OP("AssertCardinalityDataset" ) |
22 | .Input("input_dataset: variant" ) |
23 | .Input("cardinality: int64" ) |
24 | .Output("handle: variant" ) |
25 | .Attr("output_types: list(type) >= 1" ) |
26 | .Attr("output_shapes: list(shape) >= 1" ) |
27 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
28 | "output_types" )) |
29 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
30 | shape_inference::ShapeHandle unused; |
31 | // cardinality should be a scalar. |
32 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
33 | return shape_inference::ScalarShape(c); |
34 | }); |
35 | |
36 | REGISTER_OP("AssertNextDataset" ) |
37 | .Input("input_dataset: variant" ) |
38 | .Input("transformations: string" ) |
39 | .Output("handle: variant" ) |
40 | .Attr("output_types: list(type) >= 1" ) |
41 | .Attr("output_shapes: list(shape) >= 1" ) |
42 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
43 | "output_types" )) |
44 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
45 | shape_inference::ShapeHandle unused; |
46 | // transformations should be a vector. |
47 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
48 | return shape_inference::ScalarShape(c); |
49 | }); |
50 | |
51 | REGISTER_OP("ExperimentalAssertNextDataset" ) |
52 | .Input("input_dataset: variant" ) |
53 | .Input("transformations: string" ) |
54 | .Output("handle: variant" ) |
55 | .Attr("output_types: list(type) >= 1" ) |
56 | .Attr("output_shapes: list(shape) >= 1" ) |
57 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
58 | "output_types" )) |
59 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
60 | shape_inference::ShapeHandle unused; |
61 | // transformations should be a vector. |
62 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
63 | return shape_inference::ScalarShape(c); |
64 | }); |
65 | |
66 | REGISTER_OP("AssertPrevDataset" ) |
67 | .Input("input_dataset: variant" ) |
68 | .Input("transformations: string" ) |
69 | .Output("handle: variant" ) |
70 | .Attr("output_types: list(type) >= 1" ) |
71 | .Attr("output_shapes: list(shape) >= 1" ) |
72 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
73 | "output_types" )) |
74 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
75 | shape_inference::ShapeHandle unused; |
76 | // transformations should be a vector. |
77 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
78 | return shape_inference::ScalarShape(c); |
79 | }); |
80 | |
// AutoShardDataset: takes `input_dataset` together with int64 `num_workers`
// and `index` inputs and produces a dataset `handle`.
//   * `auto_shard_policy` and `num_replicas` are int attrs that default to 0.
//   * The forward type function maps the full type of input 0 through
//     `full_type::ShardTensor`.
//   * The shape function only declares the output to be a scalar; the ranks
//     of `num_workers` and `index` are not checked here.
REGISTER_OP("AutoShardDataset")
    .Input("input_dataset: variant")
    .Input("num_workers: int64")
    .Input("index: int64")
    .Output("handle: variant")
    .Attr("auto_shard_policy: int = 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("num_replicas: int = 0")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
                                              full_type::ShardTensor))
    .SetShapeFn(shape_inference::ScalarShape);
95 | |
// ExperimentalAutoShardDataset: same inputs and `auto_shard_policy` attr as
// AutoShardDataset, but it has no `num_replicas` attr and no forward type
// function. The output handle is declared to be a scalar.
REGISTER_OP("ExperimentalAutoShardDataset")
    .Input("input_dataset: variant")
    .Input("num_workers: int64")
    .Input("index: int64")
    .Output("handle: variant")
    .Attr("auto_shard_policy: int = 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
107 | |
108 | REGISTER_OP("BytesProducedStatsDataset" ) |
109 | .Input("input_dataset: variant" ) |
110 | .Input("tag: string" ) |
111 | .Output("handle: variant" ) |
112 | .Attr("output_types: list(type) >= 1" ) |
113 | .Attr("output_shapes: list(shape) >= 1" ) |
114 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
115 | "output_types" )) |
116 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
117 | shape_inference::ShapeHandle tag_shape; |
118 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); |
119 | return shape_inference::ScalarShape(c); |
120 | }); |
121 | |
122 | REGISTER_OP("ExperimentalBytesProducedStatsDataset" ) |
123 | .Input("input_dataset: variant" ) |
124 | .Input("tag: string" ) |
125 | .Output("handle: variant" ) |
126 | .Attr("output_types: list(type) >= 1" ) |
127 | .Attr("output_shapes: list(shape) >= 1" ) |
128 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
129 | "output_types" )) |
130 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
131 | shape_inference::ShapeHandle tag_shape; |
132 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); |
133 | return shape_inference::ScalarShape(c); |
134 | }); |
135 | |
// ChooseFastestBranchDataset: takes `input_dataset`, int64 `ratio_numerator`
// and `ratio_denominator`, and the variadic `other_arguments` (typed by
// `Targuments`), and produces a dataset `handle`.
//   * `branches` is a non-empty list of functions.
//   * `other_arguments_lengths` is a non-empty list of ints.
//     NOTE(review): presumably one length per branch, partitioning
//     `other_arguments`; confirm against the kernel.
//   * `num_elements_per_branch` must be >= 1.
// The shape function only declares the output handle to be a scalar.
REGISTER_OP("ChooseFastestBranchDataset")
    .Input("input_dataset: variant")
    .Input("ratio_numerator: int64")
    .Input("ratio_denominator: int64")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("Targuments: list(type) >= 0")
    .Attr("num_elements_per_branch: int >= 1")
    .Attr("branches: list(func) >= 1")
    .Attr("other_arguments_lengths: list(int) >= 1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
151 | |
// ChooseFastestDataset: takes N >= 2 `input_datasets` and produces a single
// dataset `handle`. `num_experiments` is a required int attr (no default).
// The shape function only declares the output handle to be a scalar.
REGISTER_OP("ChooseFastestDataset")
    .Input("input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("N: int >= 2")
    .Attr("num_experiments: int")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
162 | |
// ExperimentalChooseFastestDataset: same inputs, attrs and shape function as
// ChooseFastestDataset (N >= 2 input datasets, required `num_experiments`,
// scalar output handle).
REGISTER_OP("ExperimentalChooseFastestDataset")
    .Input("input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("N: int >= 2")
    .Attr("num_experiments: int")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
173 | |
// CompressElement: takes the tensors in `components` (one per entry of the
// non-empty `input_types` list) and produces a single `compressed` variant,
// which the shape function declares to be a scalar.
REGISTER_OP("CompressElement")
    .Input("components: input_types")
    .Output("compressed: variant")
    .Attr("input_types: list(type) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
179 | |
// UncompressElement: takes a `compressed` variant and produces `components`
// typed by `output_types`. Output shapes are set from the `output_shapes`
// attr by `shape_inference::DatasetIteratorShape`.
// NOTE(review): presumably the inverse of CompressElement; confirm.
REGISTER_OP("UncompressElement")
    .Input("compressed: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
186 | |
// ComputeBatchSize: takes `input_dataset` and returns an int64 `batch_size`,
// which the shape function declares to be a scalar.
REGISTER_OP("ComputeBatchSize")
    .Input("input_dataset : variant")
    .Output("batch_size : int64")
    .SetShapeFn(shape_inference::ScalarShape);
191 | |
192 | REGISTER_OP("CSVDataset" ) |
193 | .Input("filenames: string" ) |
194 | .Input("compression_type: string" ) |
195 | .Input("buffer_size: int64" ) |
196 | .Input("header: bool" ) |
197 | .Input("field_delim: string" ) |
198 | .Input("use_quote_delim: bool" ) |
199 | .Input("na_value: string" ) |
200 | .Input("select_cols: int64" ) |
201 | .Input("record_defaults: output_types" ) |
202 | .Output("handle: variant" ) |
203 | .Attr("output_types: list({float,double,int32,int64,string}) >= 1" ) |
204 | .Attr("output_shapes: list(shape) >= 1" ) |
205 | .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. |
206 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
207 | "output_types" )) |
208 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
209 | shape_inference::ShapeHandle unused; |
210 | // `filenames` must be a scalar or a vector. |
211 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); |
212 | // `compression_type`, `buffer_size`, `header`, `field_delim`, |
213 | // `use_quote_delim`, `na_value` must be scalars |
214 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
215 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
216 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
217 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
218 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
219 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); |
220 | // `select_cols` must be a vector |
221 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); |
222 | // `record_defaults` must be lists of scalars |
223 | for (size_t i = 8; i < c->num_inputs(); ++i) { |
224 | shape_inference::ShapeHandle v; |
225 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); |
226 | if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { |
227 | return errors::InvalidArgument( |
228 | "Shape of a default must be a length-0 or length-1 vector, or a " |
229 | "scalar." ); |
230 | } |
231 | } |
232 | return shape_inference::ScalarShape(c); |
233 | }); |
234 | |
235 | REGISTER_OP("CSVDatasetV2" ) |
236 | .Input("filenames: string" ) |
237 | .Input("compression_type: string" ) |
238 | .Input("buffer_size: int64" ) |
239 | .Input("header: bool" ) |
240 | .Input("field_delim: string" ) |
241 | .Input("use_quote_delim: bool" ) |
242 | .Input("na_value: string" ) |
243 | .Input("select_cols: int64" ) |
244 | .Input("record_defaults: output_types" ) |
245 | .Input("exclude_cols: int64" ) |
246 | .Output("handle: variant" ) |
247 | .Attr("output_types: list({float,double,int32,int64,string}) >= 1" ) |
248 | .Attr("output_shapes: list(shape) >= 1" ) |
249 | .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. |
250 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
251 | "output_types" )) |
252 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
253 | shape_inference::ShapeHandle unused; |
254 | // `filenames` must be a scalar or a vector. |
255 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); |
256 | // `compression_type`, `buffer_size`, `header`, `field_delim`, |
257 | // `use_quote_delim`, `na_value` must be scalars |
258 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
259 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
260 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
261 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
262 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
263 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); |
264 | // `select_cols` must be a vector |
265 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); |
266 | // `exclude_cols` must be a vector |
267 | TF_RETURN_IF_ERROR( |
268 | c->WithRank(c->input(c->num_inputs() - 1), 1, &unused)); |
269 | // `record_defaults` must be lists of scalars |
270 | for (size_t i = 8; i < c->num_inputs() - 1; ++i) { |
271 | shape_inference::ShapeHandle v; |
272 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); |
273 | if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { |
274 | return errors::InvalidArgument( |
275 | "Shape of a default must be a length-0 or length-1 vector, or a " |
276 | "scalar." ); |
277 | } |
278 | } |
279 | return shape_inference::ScalarShape(c); |
280 | }); |
281 | |
282 | REGISTER_OP("ExperimentalCSVDataset" ) |
283 | .Input("filenames: string" ) |
284 | .Input("compression_type: string" ) |
285 | .Input("buffer_size: int64" ) |
286 | .Input("header: bool" ) |
287 | .Input("field_delim: string" ) |
288 | .Input("use_quote_delim: bool" ) |
289 | .Input("na_value: string" ) |
290 | .Input("select_cols: int64" ) |
291 | .Input("record_defaults: output_types" ) |
292 | .Output("handle: variant" ) |
293 | .Attr("output_types: list({float,double,int32,int64,string}) >= 1" ) |
294 | .Attr("output_shapes: list(shape) >= 1" ) |
295 | .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. |
296 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
297 | "output_types" )) |
298 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
299 | shape_inference::ShapeHandle unused; |
300 | // `filenames` must be a scalar or a vector. |
301 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); |
302 | // `compression_type`, `buffer_size`, `header`, `field_delim`, |
303 | // `use_quote_delim`, `na_value` must be scalars |
304 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
305 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
306 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
307 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
308 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
309 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); |
310 | // `select_cols` must be a vector |
311 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); |
312 | // `record_defaults` must be lists of scalars |
313 | for (size_t i = 8; i < c->num_inputs(); ++i) { |
314 | shape_inference::ShapeHandle v; |
315 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); |
316 | if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { |
317 | return errors::InvalidArgument( |
318 | "Shape of a default must be a length-0 or length-1 vector, or a " |
319 | "scalar." ); |
320 | } |
321 | } |
322 | return shape_inference::ScalarShape(c); |
323 | }); |
324 | |
// ExperimentalDatasetCardinality: takes `input_dataset` and returns an int64
// `cardinality`, which the shape function declares to be a scalar.
REGISTER_OP("ExperimentalDatasetCardinality")
    .Input("input_dataset: variant")
    .Output("cardinality: int64")
    .SetShapeFn(shape_inference::ScalarShape);
329 | |
// DatasetFromGraph: takes a string `graph_def` and produces a scalar dataset
// `handle`. The type constructor is the generic TFT_DATASET wrapper, and the
// forward type function is `full_type::Decode(TFT_STRING, 0)`.
REGISTER_OP("DatasetFromGraph")
    .Input("graph_def: string")
    .Output("handle: variant")
    .SetTypeConstructor(full_type::UnaryGeneric(TFT_DATASET))
    .SetForwardTypeFn(full_type::Decode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);
336 | |
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
//
// DatasetToTFRecord: consumes `input_dataset` with string `filename` and
// `compression_type` inputs. It has no outputs (shape_inference::NoOutputs)
// and is registered as stateful.
// NOTE(review): the shape function does not validate the ranks of `filename`
// or `compression_type`.
REGISTER_OP("DatasetToTFRecord")
    .Input("input_dataset: variant")
    .Input("filename: string")
    .Input("compression_type: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);
347 | |
// ExperimentalDatasetToTFRecord: same inputs as DatasetToTFRecord. It has no
// outputs and is registered as stateful.
REGISTER_OP("ExperimentalDatasetToTFRecord")
    .Input("input_dataset: variant")
    .Input("filename: string")
    .Input("compression_type: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);
354 | |
355 | REGISTER_OP("DenseToSparseBatchDataset" ) |
356 | .Input("input_dataset: variant" ) |
357 | .Input("batch_size: int64" ) |
358 | .Input("row_shape: int64" ) |
359 | .Output("handle: variant" ) |
360 | .Attr("output_types: list(type) >= 1" ) |
361 | .Attr("output_shapes: list(shape) >= 1" ) |
362 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
363 | "output_types" )) |
364 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
365 | shape_inference::ShapeHandle unused; |
366 | // batch_size should be a scalar. |
367 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
368 | // row_shape should be a 1-D vector. |
369 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
370 | return shape_inference::ScalarShape(c); |
371 | }); |
372 | |
373 | REGISTER_OP("ExperimentalDenseToSparseBatchDataset" ) |
374 | .Input("input_dataset: variant" ) |
375 | .Input("batch_size: int64" ) |
376 | .Input("row_shape: int64" ) |
377 | .Output("handle: variant" ) |
378 | .Attr("output_types: list(type) >= 1" ) |
379 | .Attr("output_shapes: list(shape) >= 1" ) |
380 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
381 | "output_types" )) |
382 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
383 | shape_inference::ShapeHandle unused; |
384 | // batch_size should be a scalar. |
385 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
386 | // row_shape should be a 1-D vector. |
387 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
388 | return shape_inference::ScalarShape(c); |
389 | }); |
390 | |
// DirectedInterleaveDataset: takes a `selector_input_dataset` and N >= 1
// `data_input_datasets`, and produces a dataset `handle`.
// `stop_on_empty_dataset` is a bool attr that defaults to false. The shape
// function only declares the output handle to be a scalar.
REGISTER_OP("DirectedInterleaveDataset")
    .Input("selector_input_dataset: variant")
    .Input("data_input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .Attr("stop_on_empty_dataset: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
402 | |
// ExperimentalDirectedInterleaveDataset: same inputs as
// DirectedInterleaveDataset but without the `stop_on_empty_dataset` attr.
// The output handle is declared to be a scalar.
REGISTER_OP("ExperimentalDirectedInterleaveDataset")
    .Input("selector_input_dataset: variant")
    .Input("data_input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
413 | |
// GroupByReducerDataset: takes `input_dataset` and four function attrs
// (`key_func`, `init_func`, `reduce_func`, `finalize_func`). Each function is
// paired with a variadic `*_other_arguments` input typed by the matching
// `T*_other_arguments` attr (each list may be empty). The op is registered as
// stateful, and the shape function declares the output handle to be a scalar.
REGISTER_OP("GroupByReducerDataset")
    .Input("input_dataset: variant")
    .Input("key_func_other_arguments: Tkey_func_other_arguments")
    .Input("init_func_other_arguments: Tinit_func_other_arguments")
    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
    .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
    .Output("handle: variant")
    .Attr("key_func: func")
    .Attr("init_func: func")
    .Attr("reduce_func: func")
    .Attr("finalize_func: func")
    .Attr("Tkey_func_other_arguments: list(type) >= 0")
    .Attr("Tinit_func_other_arguments: list(type) >= 0")
    .Attr("Treduce_func_other_arguments: list(type) >= 0")
    .Attr("Tfinalize_func_other_arguments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
435 | |
// ExperimentalGroupByReducerDataset: same inputs, attrs, statefulness and
// shape function as GroupByReducerDataset.
REGISTER_OP("ExperimentalGroupByReducerDataset")
    .Input("input_dataset: variant")
    .Input("key_func_other_arguments: Tkey_func_other_arguments")
    .Input("init_func_other_arguments: Tinit_func_other_arguments")
    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
    .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
    .Output("handle: variant")
    .Attr("key_func: func")
    .Attr("init_func: func")
    .Attr("reduce_func: func")
    .Attr("finalize_func: func")
    .Attr("Tkey_func_other_arguments: list(type) >= 0")
    .Attr("Tinit_func_other_arguments: list(type) >= 0")
    .Attr("Treduce_func_other_arguments: list(type) >= 0")
    .Attr("Tfinalize_func_other_arguments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetIsStateful()
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
457 | |
// GroupByWindowDataset: takes `input_dataset` and three function attrs
// (`key_func`, `reduce_func`, `window_size_func`). Each function is paired
// with a variadic `*_other_arguments` input typed by the matching
// `T*_other_arguments` attr. `metadata` is a string attr that defaults to
// ''. The shape function declares the output handle to be a scalar.
REGISTER_OP("GroupByWindowDataset")
    .Input("input_dataset: variant")
    .Input("key_func_other_arguments: Tkey_func_other_arguments")
    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
    .Input(
        "window_size_func_other_arguments: Twindow_size_func_other_arguments")
    .Output("handle: variant")
    .Attr("key_func: func")
    .Attr("reduce_func: func")
    .Attr("window_size_func: func")
    .Attr("Tkey_func_other_arguments: list(type) >= 0")
    .Attr("Treduce_func_other_arguments: list(type) >= 0")
    .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
477 | |
// GetElementAtIndex: takes a `dataset` variant and an int64 `index`, and
// produces `components` typed by `output_types`. Output shapes are set from
// the `output_shapes` attr by `shape_inference::DatasetIteratorShape`.
// NOTE(review): the rank of `index` is not validated here. The type
// constructor also declares a TFT_DATASET container even though the outputs
// are element components rather than a dataset handle; confirm this is
// intended.
REGISTER_OP("GetElementAtIndex")
    .Input("dataset: variant")
    .Input("index: int64")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::DatasetIteratorShape);
487 | |
// ExperimentalGroupByWindowDataset: same inputs and function attrs as
// GroupByWindowDataset but without the `metadata` attr. The output handle is
// declared to be a scalar.
REGISTER_OP("ExperimentalGroupByWindowDataset")
    .Input("input_dataset: variant")
    .Input("key_func_other_arguments: Tkey_func_other_arguments")
    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
    .Input(
        "window_size_func_other_arguments: Twindow_size_func_other_arguments")
    .Output("handle: variant")
    .Attr("key_func: func")
    .Attr("reduce_func: func")
    .Attr("window_size_func: func")
    .Attr("Tkey_func_other_arguments: list(type) >= 0")
    .Attr("Treduce_func_other_arguments: list(type) >= 0")
    .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
506 | |
// IgnoreErrorsDataset: takes `input_dataset` and produces a dataset `handle`.
// `log_warning` is a bool attr that defaults to false. The shape function
// declares the output handle to be a scalar.
REGISTER_OP("IgnoreErrorsDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("log_warning: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
516 | |
// ExperimentalIgnoreErrorsDataset: same signature as IgnoreErrorsDataset,
// including the `log_warning` attr (default false). The output handle is
// declared to be a scalar.
REGISTER_OP("ExperimentalIgnoreErrorsDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("log_warning: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
526 | |
// IteratorGetDevice: takes a `resource` handle and returns a string `device`,
// which the shape function declares to be a scalar.
REGISTER_OP("IteratorGetDevice")
    .Input("resource: resource")
    .Output("device: string")
    .SetShapeFn(shape_inference::ScalarShape);
531 | |
// ExperimentalIteratorGetDevice: same signature as IteratorGetDevice; returns
// a scalar string `device` for the given `resource` handle.
REGISTER_OP("ExperimentalIteratorGetDevice")
    .Input("resource: resource")
    .Output("device: string")
    .SetShapeFn(shape_inference::ScalarShape);
536 | |
537 | REGISTER_OP("LatencyStatsDataset" ) |
538 | .Input("input_dataset: variant" ) |
539 | .Input("tag: string" ) |
540 | .Output("handle: variant" ) |
541 | .Attr("output_types: list(type) >= 1" ) |
542 | .Attr("output_shapes: list(shape) >= 1" ) |
543 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
544 | "output_types" )) |
545 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
546 | shape_inference::ShapeHandle tag_shape; |
547 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); |
548 | return shape_inference::ScalarShape(c); |
549 | }); |
550 | |
551 | REGISTER_OP("ExperimentalLatencyStatsDataset" ) |
552 | .Input("input_dataset: variant" ) |
553 | .Input("tag: string" ) |
554 | .Output("handle: variant" ) |
555 | .Attr("output_types: list(type) >= 1" ) |
556 | .Attr("output_shapes: list(shape) >= 1" ) |
557 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
558 | "output_types" )) |
559 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
560 | shape_inference::ShapeHandle tag_shape; |
561 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape)); |
562 | return shape_inference::ScalarShape(c); |
563 | }); |
564 | |
// LMDBDataset: takes string `filenames` and produces a dataset `handle`.
// Marked SetDoNotOptimize (see the TODO below). The shape function only
// declares the output handle to be a scalar; the rank of `filenames` is not
// checked here.
REGISTER_OP("LMDBDataset")
    .Input("filenames: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
574 | |
// ExperimentalLMDBDataset: same signature as LMDBDataset. It is marked
// SetDoNotOptimize, and its output handle is declared to be a scalar.
REGISTER_OP("ExperimentalLMDBDataset")
    .Input("filenames: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
584 | |
585 | REGISTER_OP("MapAndBatchDataset" ) |
586 | .Input("input_dataset: variant" ) |
587 | .Input("other_arguments: Targuments" ) |
588 | .Input("batch_size: int64" ) |
589 | .Input("num_parallel_calls: int64" ) |
590 | .Input("drop_remainder: bool" ) |
591 | .Output("handle: variant" ) |
592 | .Attr("f: func" ) |
593 | .Attr("Targuments: list(type) >= 0" ) |
594 | .Attr("output_types: list(type) >= 1" ) |
595 | .Attr("output_shapes: list(shape) >= 1" ) |
596 | .Attr("preserve_cardinality: bool = false" ) |
597 | .Attr("metadata: string = ''" ) |
598 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
599 | "output_types" )) |
600 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
601 | // Use index from the end to retrieve the Input shapes, |
602 | // so that to avoid guessing the length of "other_arguments". |
603 | // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars. |
604 | shape_inference::ShapeHandle unused; |
605 | TF_RETURN_IF_ERROR( |
606 | c->WithRank(c->input(c->num_inputs() - 3), 0, &unused)); |
607 | TF_RETURN_IF_ERROR( |
608 | c->WithRank(c->input(c->num_inputs() - 2), 0, &unused)); |
609 | TF_RETURN_IF_ERROR( |
610 | c->WithRank(c->input(c->num_inputs() - 1), 0, &unused)); |
611 | |
612 | return shape_inference::ScalarShape(c); |
613 | }); |
614 | |
615 | REGISTER_OP("ExperimentalMapAndBatchDataset" ) |
616 | .Input("input_dataset: variant" ) |
617 | .Input("other_arguments: Targuments" ) |
618 | .Input("batch_size: int64" ) |
619 | .Input("num_parallel_calls: int64" ) |
620 | .Input("drop_remainder: bool" ) |
621 | .Output("handle: variant" ) |
622 | .Attr("f: func" ) |
623 | .Attr("Targuments: list(type) >= 0" ) |
624 | .Attr("output_types: list(type) >= 1" ) |
625 | .Attr("output_shapes: list(shape) >= 1" ) |
626 | .Attr("preserve_cardinality: bool = false" ) |
627 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
628 | "output_types" )) |
629 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
630 | // Use index from the end to retrieve the Input shapes, |
631 | // so that to avoid guessing the length of "other_arguments". |
632 | // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars. |
633 | shape_inference::ShapeHandle unused; |
634 | TF_RETURN_IF_ERROR( |
635 | c->WithRank(c->input(c->num_inputs() - 3), 0, &unused)); |
636 | TF_RETURN_IF_ERROR( |
637 | c->WithRank(c->input(c->num_inputs() - 2), 0, &unused)); |
638 | TF_RETURN_IF_ERROR( |
639 | c->WithRank(c->input(c->num_inputs() - 1), 0, &unused)); |
640 | |
641 | return shape_inference::ScalarShape(c); |
642 | }); |
643 | |
// ExperimentalMapDataset: takes `input_dataset` and the variadic
// `other_arguments` (typed by `Targuments`) for the function attr `f`, and
// produces a dataset `handle`. `use_inter_op_parallelism` defaults to true
// and `preserve_cardinality` defaults to false. The shape function only
// declares the output handle to be a scalar.
REGISTER_OP("ExperimentalMapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("preserve_cardinality: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
657 | |
658 | REGISTER_OP("MatchingFilesDataset" ) |
659 | .Input("patterns: string" ) |
660 | .Output("handle: variant" ) |
661 | .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. |
662 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, |
663 | TFT_STRING)) |
664 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
665 | shape_inference::ShapeHandle unused; |
666 | // `patterns` must be a scalar or a vector. |
667 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); |
668 | return shape_inference::ScalarShape(c); |
669 | }); |
670 | |
671 | REGISTER_OP("ExperimentalMatchingFilesDataset" ) |
672 | .Input("patterns: string" ) |
673 | .Output("handle: variant" ) |
674 | .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. |
675 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, |
676 | TFT_STRING)) |
677 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
678 | shape_inference::ShapeHandle unused; |
679 | // `patterns` must be a scalar or a vector. |
680 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); |
681 | return shape_inference::ScalarShape(c); |
682 | }); |
683 | |
// MaxIntraOpParallelismDataset: takes `input_dataset` and an int64
// `max_intra_op_parallelism`, and produces a dataset `handle`. The shape
// function only declares the output handle to be a scalar; the rank of
// `max_intra_op_parallelism` is not checked here.
REGISTER_OP("MaxIntraOpParallelismDataset")
    .Input("input_dataset: variant")
    .Input("max_intra_op_parallelism: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
693 | |
// ExperimentalMaxIntraOpParallelismDataset: same signature and shape function
// as MaxIntraOpParallelismDataset.
REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset")
    .Input("input_dataset: variant")
    .Input("max_intra_op_parallelism: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
703 | |
// NonSerializableDataset: takes `input_dataset` and produces a dataset
// `handle`, which the shape function declares to be a scalar.
REGISTER_OP("NonSerializableDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
712 | |
713 | REGISTER_OP("ExperimentalNonSerializableDataset" ) |
714 | .Input("input_dataset: variant" ) |
715 | .Output("handle: variant" ) |
716 | .Attr("output_types: list(type) >= 1" ) |
717 | .Attr("output_shapes: list(shape) >= 1" ) |
718 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
719 | "output_types" )) |
720 | .SetShapeFn(shape_inference::ScalarShape); |
721 | |
// ParallelInterleaveDataset: takes `input_dataset`, extra `other_arguments`
// (typed by `Targuments`) for the function attr `f`, and the control inputs
// `cycle_length`, `block_length`, `sloppy`, `buffer_output_elements`, and
// `prefetch_input_elements`. The shape fn only declares a scalar output
// handle; the control inputs are not rank-checked here.
REGISTER_OP("ParallelInterleaveDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("sloppy: bool")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// This is the V2 of ParallelInterleaveDataset, renamed to differentiate it
// from the non-experimental ParallelInterleaveDataset op.
// Compared with ParallelInterleaveDataset, the `sloppy` bool *input* is
// replaced by the `deterministic` string *attr*.
REGISTER_OP("LegacyParallelInterleaveDatasetV2")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Output("handle: variant")
    .Attr("f: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// This op is no longer used. We keep it so that we can read graphs written by
// old versions of TensorFlow.
// Same inputs as ParallelInterleaveDataset, but without the `metadata` attr.
REGISTER_OP("ExperimentalParallelInterleaveDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("sloppy: bool")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
779 | |
// ParseExampleDataset: takes `input_dataset`, an int64 `num_parallel_calls`,
// and `dense_defaults` (typed by `Tdense`). Feature configuration is supplied
// through the sparse/dense/ragged key, type, and shape attrs. Ordering is
// controlled by the `sloppy` bool attr.
REGISTER_OP("ParseExampleDataset")
    .Input("input_dataset: variant")
    .Input("num_parallel_calls: int64")
    .Input("dense_defaults: Tdense")
    .Output("handle: variant")
    .Attr("sparse_keys: list(string) >= 0")
    .Attr("dense_keys: list(string) >= 0")
    .Attr("sparse_types: list({float,int64,string}) >= 0")
    .Attr("Tdense: list({float,int64,string}) >= 0")
    .Attr("dense_shapes: list(shape) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")  // Output components will be
                                               // sorted by key (dense_keys and
                                               // sparse_keys combined) here.
    .Attr("sloppy: bool = false")
    .Attr("ragged_keys: list(string) >= 0 = []")
    .Attr("ragged_value_types: list({float,int64,string}) >= 0 = []")
    .Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// ParseExampleDatasetV2: same inputs and feature attrs as ParseExampleDataset,
// but the `sloppy` bool attr is replaced by the `deterministic` string attr.
REGISTER_OP("ParseExampleDatasetV2")
    .Input("input_dataset: variant")
    .Input("num_parallel_calls: int64")
    .Input("dense_defaults: Tdense")
    .Output("handle: variant")
    .Attr("sparse_keys: list(string) >= 0")
    .Attr("dense_keys: list(string) >= 0")
    .Attr("sparse_types: list({float,int64,string}) >= 0")
    .Attr("Tdense: list({float,int64,string}) >= 0")
    .Attr("dense_shapes: list(shape) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")  // Output components will be
                                               // sorted by key (dense_keys and
                                               // sparse_keys combined) here.
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("ragged_keys: list(string) >= 0 = []")
    .Attr("ragged_value_types: list({float,int64,string}) >= 0 = []")
    .Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Experimental-prefixed variant of ParseExampleDataset. It keeps the `sloppy`
// attr but has no ragged_* attrs.
REGISTER_OP("ExperimentalParseExampleDataset")
    .Input("input_dataset: variant")
    .Input("num_parallel_calls: int64")
    .Input("dense_defaults: Tdense")
    .Output("handle: variant")
    .Attr("sparse_keys: list(string) >= 0")
    .Attr("dense_keys: list(string) >= 0")
    .Attr("sparse_types: list({float,int64,string}) >= 0")
    .Attr("Tdense: list({float,int64,string}) >= 0")
    .Attr("dense_shapes: list(shape) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")  // Output components will be
                                               // sorted by key (dense_keys and
                                               // sparse_keys combined) here.
    .Attr("sloppy: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
843 | |
// PrivateThreadPoolDataset: takes `input_dataset` plus an int64 `num_threads`
// input and emits a new dataset handle. The shape fn declares only the scalar
// output; `num_threads` is not rank-checked here.
REGISTER_OP("PrivateThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("num_threads: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Experimental-prefixed registration with the same signature as
// PrivateThreadPoolDataset.
REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("num_threads: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
863 | |
864 | REGISTER_OP("ExperimentalRandomDataset" ) |
865 | .Input("seed: int64" ) |
866 | .Input("seed2: int64" ) |
867 | .Output("handle: variant" ) |
868 | .Attr("output_types: list(type) >= 1" ) |
869 | .Attr("output_shapes: list(shape) >= 1" ) |
870 | .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. |
871 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
872 | "output_types" )) |
873 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
874 | shape_inference::ShapeHandle unused; |
875 | // buffer_size, seed, and seed2 should be scalars. |
876 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
877 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
878 | return shape_inference::ScalarShape(c); |
879 | }); |
880 | |
881 | REGISTER_OP("RandomDataset" ) |
882 | .Input("seed: int64" ) |
883 | .Input("seed2: int64" ) |
884 | .Output("handle: variant" ) |
885 | .Attr("output_types: list(type) >= 1" ) |
886 | .Attr("output_shapes: list(shape) >= 1" ) |
887 | .Attr("metadata: string = ''" ) |
888 | .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. |
889 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
890 | "output_types" )) |
891 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
892 | shape_inference::ShapeHandle unused; |
893 | // buffer_size, seed, and seed2 should be scalars. |
894 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
895 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
896 | return shape_inference::ScalarShape(c); |
897 | }); |
898 | |
// ExperimentalRebatchDataset: takes `input_dataset` plus an int64
// `num_replicas` input and a `use_fallback` bool attr (default true). Unlike
// RebatchDataset below, it registers no forward type function.
REGISTER_OP("ExperimentalRebatchDataset")
    .Input("input_dataset: variant")
    .Input("num_replicas: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_fallback: bool = true")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// RebatchDataset: same signature as ExperimentalRebatchDataset. It also
// registers a forward type function that maps full_type::BatchTensor over the
// TFT_DATASET container of input 0.
REGISTER_OP("RebatchDataset")
    .Input("input_dataset: variant")
    .Input("num_replicas: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_fallback: bool = true")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
                                              full_type::BatchTensor))
    .SetShapeFn(shape_inference::ScalarShape);

// RebatchDatasetV2: replaces `num_replicas` with an int64 `batch_sizes` input
// and a bool `drop_remainder` input, and has no `use_fallback` attr. It uses
// the same forward type function as RebatchDataset.
REGISTER_OP("RebatchDatasetV2")
    .Input("input_dataset: variant")
    .Input("batch_sizes: int64")
    .Input("drop_remainder: bool")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
                                              full_type::BatchTensor))
    .SetShapeFn(shape_inference::ScalarShape);
935 | |
936 | REGISTER_OP("SamplingDataset" ) |
937 | .Input("input_dataset: variant" ) |
938 | .Input("rate: float32" ) |
939 | .Input("seed: int64" ) |
940 | .Input("seed2: int64" ) |
941 | .Output("handle: variant" ) |
942 | .Attr("output_types: list(type) >= 1" ) |
943 | .Attr("output_shapes: list(shape) >= 1" ) |
944 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
945 | "output_types" )) |
946 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
947 | shape_inference::ShapeHandle unused; |
948 | // rate, seed, and seed2 should be scalars. |
949 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
950 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
951 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
952 | return shape_inference::ScalarShape(c); |
953 | }); |
954 | |
// ScanDataset: takes `input_dataset`, an `initial_state` (typed by `Tstate`,
// at least one component), and extra `other_arguments` (typed by
// `Targuments`) for the function attr `f`. The shape fn declares only the
// scalar output handle.
REGISTER_OP("ScanDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("preserve_cardinality: bool = false")
    .Attr("use_default_device: bool = true")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Experimental-prefixed variant of ScanDataset without the
// `use_default_device` and `metadata` attrs.
REGISTER_OP("ExperimentalScanDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("preserve_cardinality: bool = false")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
986 | |
// SetStatsAggregatorDataset: takes `input_dataset`, a `stats_aggregator`
// resource, and two string inputs, `tag` and `counter_prefix`. The shape fn
// declares only the scalar output handle; the other inputs are not
// rank-checked here.
REGISTER_OP("SetStatsAggregatorDataset")
    .Input("input_dataset: variant")
    .Input("stats_aggregator: resource")
    .Input("tag: string")
    .Input("counter_prefix: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Experimental-prefixed registration with the same signature as
// SetStatsAggregatorDataset.
REGISTER_OP("ExperimentalSetStatsAggregatorDataset")
    .Input("input_dataset: variant")
    .Input("stats_aggregator: resource")
    .Input("tag: string")
    .Input("counter_prefix: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1010 | |
1011 | REGISTER_OP("SleepDataset" ) |
1012 | .Input("input_dataset: variant" ) |
1013 | .Input("sleep_microseconds: int64" ) |
1014 | .Output("handle: variant" ) |
1015 | .Attr("output_types: list(type) >= 1" ) |
1016 | .Attr("output_shapes: list(shape) >= 1" ) |
1017 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1018 | "output_types" )) |
1019 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1020 | shape_inference::ShapeHandle unused; |
1021 | // Both inputs are scalar. |
1022 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused)); |
1023 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused)); |
1024 | return shape_inference::ScalarShape(c); |
1025 | }); |
1026 | |
1027 | REGISTER_OP("ExperimentalSleepDataset" ) |
1028 | .Input("input_dataset: variant" ) |
1029 | .Input("sleep_microseconds: int64" ) |
1030 | .Output("handle: variant" ) |
1031 | .Attr("output_types: list(type) >= 1" ) |
1032 | .Attr("output_shapes: list(shape) >= 1" ) |
1033 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1034 | "output_types" )) |
1035 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1036 | shape_inference::ShapeHandle unused; |
1037 | // Both inputs are scalar. |
1038 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused)); |
1039 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused)); |
1040 | return shape_inference::ScalarShape(c); |
1041 | }); |
1042 | |
1043 | REGISTER_OP("SlidingWindowDataset" ) |
1044 | .Input("input_dataset: variant" ) |
1045 | .Input("window_size: int64" ) |
1046 | .Input("window_shift: int64" ) |
1047 | .Input("window_stride: int64" ) |
1048 | .Output("handle: variant" ) |
1049 | .Attr("drop_remainder: bool = true" ) |
1050 | .Attr("output_types: list(type) >= 1" ) |
1051 | .Attr("output_shapes: list(shape) >= 1" ) |
1052 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1053 | "output_types" )) |
1054 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1055 | shape_inference::ShapeHandle unused; |
1056 | // window_size, window_shift, and window_stride should be scalars. |
1057 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1058 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1059 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1060 | return shape_inference::ScalarShape(c); |
1061 | }); |
1062 | |
1063 | REGISTER_OP("ExperimentalSlidingWindowDataset" ) |
1064 | .Input("input_dataset: variant" ) |
1065 | .Input("window_size: int64" ) |
1066 | .Input("window_shift: int64" ) |
1067 | .Input("window_stride: int64" ) |
1068 | .Output("handle: variant" ) |
1069 | .Attr("output_types: list(type) >= 1" ) |
1070 | .Attr("output_shapes: list(shape) >= 1" ) |
1071 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1072 | "output_types" )) |
1073 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1074 | shape_inference::ShapeHandle unused; |
1075 | // window_size, window_shift, and window_stride should be scalars. |
1076 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1077 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1078 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1079 | return shape_inference::ScalarShape(c); |
1080 | }); |
1081 | |
1082 | REGISTER_OP("SnapshotDataset" ) |
1083 | .Input("input_dataset: variant" ) |
1084 | .Input("path: string" ) |
1085 | .Output("handle: variant" ) |
1086 | .Attr("output_types: list(type) >= 1" ) |
1087 | .Attr("output_shapes: list(shape) >= 1" ) |
1088 | .Attr("compression: string = ''" ) |
1089 | .Attr("reader_path_prefix: string = ''" ) |
1090 | .Attr("writer_path_prefix: string = ''" ) |
1091 | .Attr("shard_size_bytes: int = 10737418240" ) // 10 GiB default |
1092 | .Attr("pending_snapshot_expiry_seconds: int = 86400" ) // 1 day default |
1093 | .Attr("num_reader_threads: int = 1" ) |
1094 | .Attr("reader_buffer_size: int = 1" ) |
1095 | .Attr("num_writer_threads: int = 1" ) |
1096 | .Attr("writer_buffer_size: int = 1" ) |
1097 | .Attr("shuffle_on_read: bool = false" ) |
1098 | .Attr("seed: int = 0" ) |
1099 | .Attr("seed2: int = 0" ) |
1100 | .Attr("mode: string = 'auto'" ) |
1101 | .Attr("snapshot_name: string = ''" ) |
1102 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1103 | "output_types" )) |
1104 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1105 | shape_inference::ShapeHandle unused; |
1106 | // snapshot_path should be a scalar. |
1107 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1108 | return shape_inference::ScalarShape(c); |
1109 | }); |
1110 | |
1111 | REGISTER_OP("SnapshotDatasetV2" ) |
1112 | .Input("input_dataset: variant" ) |
1113 | .Input("path: string" ) |
1114 | .Input("reader_func_other_args: Treader_func_args" ) |
1115 | .Input("shard_func_other_args: Tshard_func_args" ) |
1116 | .Output("handle: variant" ) |
1117 | .Attr("output_types: list(type) >= 1" ) |
1118 | .Attr("output_shapes: list(shape) >= 1" ) |
1119 | .Attr("compression: string = ''" ) |
1120 | .Attr("reader_prefix: string = ''" ) |
1121 | .Attr("writer_prefix: string = ''" ) |
1122 | .Attr("hash_valid: bool = false" ) |
1123 | .Attr("hash: int = 0" ) |
1124 | .Attr("reader_func: func" ) |
1125 | .Attr("shard_func: func" ) |
1126 | .Attr("Treader_func_args: list(type) >= 0" ) |
1127 | .Attr("Tshard_func_args: list(type) >= 0" ) |
1128 | .Attr("metadata: string = ''" ) |
1129 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1130 | "output_types" )) |
1131 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1132 | shape_inference::ShapeHandle unused; |
1133 | // `path` should be a scalar. |
1134 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1135 | return shape_inference::ScalarShape(c); |
1136 | }); |
1137 | |
1138 | REGISTER_OP("SaveDataset" ) |
1139 | .Input("input_dataset: variant" ) |
1140 | .Input("path: string" ) |
1141 | .Input("shard_func_other_args: Tshard_func_args" ) |
1142 | .Attr("compression: string = ''" ) |
1143 | .Attr("shard_func: func" ) |
1144 | .Attr("use_shard_func: bool = true" ) |
1145 | .Attr("Tshard_func_args: list(type) >= 0" ) |
1146 | .SetIsStateful() |
1147 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1148 | shape_inference::ShapeHandle unused; |
1149 | // `path` should be a scalar. |
1150 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1151 | return OkStatus(); |
1152 | }); |
1153 | |
1154 | REGISTER_OP("SaveDatasetV2" ) |
1155 | .Input("input_dataset: variant" ) |
1156 | .Input("path: string" ) |
1157 | .Input("shard_func_other_args: Tshard_func_args" ) |
1158 | .Output("handle: variant" ) |
1159 | .Attr("compression: string = ''" ) |
1160 | .Attr("shard_func: func" ) |
1161 | .Attr("use_shard_func: bool = true" ) |
1162 | .Attr("Tshard_func_args: list(type) >= 0" ) |
1163 | .Attr("output_types: list(type) >= 1" ) |
1164 | .Attr("output_shapes: list(shape) >= 1" ) |
1165 | .SetIsStateful() |
1166 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1167 | shape_inference::ShapeHandle unused; |
1168 | // `path` should be a scalar. |
1169 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1170 | return shape_inference::ScalarShape(c); |
1171 | }); |
1172 | |
1173 | REGISTER_OP("LoadDataset" ) |
1174 | .Input("path: string" ) |
1175 | .Input("reader_func_other_args: Treader_func_args" ) |
1176 | .Output("handle: variant" ) |
1177 | .Attr("output_types: list(type) >= 1" ) |
1178 | .Attr("output_shapes: list(shape) >= 1" ) |
1179 | .Attr("compression: string = ''" ) |
1180 | .Attr("reader_func: func" ) |
1181 | .Attr("Treader_func_args: list(type) >= 0" ) |
1182 | .SetIsStateful() |
1183 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1184 | "output_types" )) |
1185 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1186 | shape_inference::ShapeHandle unused; |
1187 | // `path` should be a scalar. |
1188 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
1189 | return shape_inference::ScalarShape(c); |
1190 | }); |
1191 | |
1192 | REGISTER_OP("SnapshotDatasetReader" ) |
1193 | .Input("shard_dir: string" ) |
1194 | .Input("start_index: int64" ) |
1195 | .Output("handle: variant" ) |
1196 | .Attr("output_types: list(type) >= 1" ) |
1197 | .Attr("output_shapes: list(shape) >= 1" ) |
1198 | .Attr("compression: string = ''" ) |
1199 | .Attr("version: int" ) |
1200 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1201 | "output_types" )) |
1202 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1203 | shape_inference::ShapeHandle unused; |
1204 | // `shard_dir` should be a scalar. |
1205 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
1206 | // `start_index` should be a scalar. |
1207 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1208 | return shape_inference::ScalarShape(c); |
1209 | }); |
1210 | |
// SnapshotNestedDatasetReader: takes `N` (N >= 1) variant inputs and emits a
// single dataset handle. NOTE(review): the inputs are presumably dataset
// handles (e.g. per-shard readers); confirm against the kernel. The shape fn
// does not inspect the inputs and only declares a scalar output.
REGISTER_OP("SnapshotNestedDatasetReader")
    .Input("inputs: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1220 | |
1221 | REGISTER_OP("SqlDataset" ) |
1222 | .Input("driver_name: string" ) |
1223 | .Input("data_source_name: string" ) |
1224 | .Input("query: string" ) |
1225 | .Output("handle: variant" ) |
1226 | .Attr("output_types: list(type) >= 1" ) |
1227 | .Attr("output_shapes: list(shape) >= 1" ) |
1228 | .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. |
1229 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1230 | "output_types" )) |
1231 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1232 | shape_inference::ShapeHandle unused; |
1233 | // driver_name, data_source_name, and query should be scalars. |
1234 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
1235 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1236 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1237 | return shape_inference::ScalarShape(c); |
1238 | }); |
1239 | |
1240 | REGISTER_OP("ExperimentalSqlDataset" ) |
1241 | .Input("driver_name: string" ) |
1242 | .Input("data_source_name: string" ) |
1243 | .Input("query: string" ) |
1244 | .Output("handle: variant" ) |
1245 | .Attr("output_types: list(type) >= 1" ) |
1246 | .Attr("output_shapes: list(shape) >= 1" ) |
1247 | .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. |
1248 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1249 | "output_types" )) |
1250 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1251 | shape_inference::ShapeHandle unused; |
1252 | // driver_name, data_source_name, and query should be scalars. |
1253 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
1254 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1255 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1256 | return shape_inference::ScalarShape(c); |
1257 | }); |
1258 | |
// StatsAggregatorHandle: no inputs. It emits a scalar resource handle, and
// the `container` / `shared_name` attrs follow the usual resource-op naming
// attrs.
REGISTER_OP("StatsAggregatorHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");

// Experimental-prefixed registration with the same signature as
// StatsAggregatorHandle.
REGISTER_OP("ExperimentalStatsAggregatorHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");

// V2 registration. Its signature is identical to StatsAggregatorHandle;
// NOTE(review): any difference lies in the kernel, not the OpDef.
REGISTER_OP("StatsAggregatorHandleV2")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");

// StatsAggregatorSetSummaryWriter: takes two resources (`stats_aggregator`
// and `summary`) and has no outputs.
REGISTER_OP("StatsAggregatorSetSummaryWriter")
    .Input("stats_aggregator: resource")
    .Input("summary: resource")
    .SetShapeFn(shape_inference::NoOutputs);

// StatsAggregatorSummary: takes one resource and emits a scalar string.
// NOTE(review): the input is named `iterator` but presumably receives a stats
// aggregator handle. The name is part of the OpDef and must not change.
REGISTER_OP("StatsAggregatorSummary")
    .Input("iterator: resource")
    .Output("summary: string")
    .SetShapeFn(shape_inference::ScalarShape);

// Experimental-prefixed registration with the same signature as
// StatsAggregatorSummary.
REGISTER_OP("ExperimentalStatsAggregatorSummary")
    .Input("iterator: resource")
    .Output("summary: string")
    .SetShapeFn(shape_inference::ScalarShape);
1291 | |
// TakeWhileDataset: takes `input_dataset` plus extra `other_arguments`
// (typed by `Targuments`) for the `predicate` function attr. The shape fn
// declares only the scalar output handle.
REGISTER_OP("TakeWhileDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("predicate: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Experimental-prefixed variant of TakeWhileDataset without the `metadata`
// attr.
REGISTER_OP("ExperimentalTakeWhileDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("predicate: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1316 | |
// ThreadPoolDataset: takes `input_dataset` plus a `thread_pool` resource and
// emits a new dataset handle.
REGISTER_OP("ThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("thread_pool: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Experimental-prefixed registration with the same signature as
// ThreadPoolDataset.
REGISTER_OP("ExperimentalThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("thread_pool: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// ThreadPoolHandle: no inputs. It emits a scalar resource configured by the
// required `num_threads` and `display_name` attrs, plus
// `max_intra_op_parallelism` (default 1) and the `container` / `shared_name`
// attrs.
REGISTER_OP("ThreadPoolHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("num_threads: int")
    .Attr("max_intra_op_parallelism: int = 1")
    .Attr("display_name: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");

// Experimental-prefixed registration with the same signature as
// ThreadPoolHandle.
REGISTER_OP("ExperimentalThreadPoolHandle")
    .Output("handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Attr("num_threads: int")
    .Attr("max_intra_op_parallelism: int = 1")
    .Attr("display_name: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''");
1354 | |
// UnbatchDataset: wraps `input_dataset` (no other inputs) and carries a
// `metadata` attr.
REGISTER_OP("UnbatchDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Experimental-prefixed variant of UnbatchDataset without the `metadata`
// attr.
REGISTER_OP("ExperimentalUnbatchDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// UniqueDataset: wraps `input_dataset` (no other inputs) and carries a
// `metadata` attr.
REGISTER_OP("UniqueDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Experimental-prefixed variant of UniqueDataset without the `metadata`
// attr.
REGISTER_OP("ExperimentalUniqueDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1392 | |
1393 | REGISTER_OP("DummyIterationCounter" ) |
1394 | .Output("handle: resource" ) |
1395 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1396 | c->set_output(0, c->Scalar()); |
1397 | return OkStatus(); |
1398 | }); |
1399 | |
1400 | REGISTER_OP("DataServiceDataset" ) |
1401 | .Input("dataset_id: int64" ) |
1402 | .Input("processing_mode: string" ) |
1403 | .Input("address: string" ) |
1404 | .Input("protocol: string" ) |
1405 | .Input("job_name: string" ) |
1406 | .Input("max_outstanding_requests: int64" ) |
1407 | .Input("iteration_counter: resource" ) |
1408 | .Output("handle: variant" ) |
1409 | .Attr("task_refresh_interval_hint_ms: int = -1" ) |
1410 | .Attr("output_types: list(type) >= 1" ) |
1411 | .Attr("output_shapes: list(shape) >= 1" ) |
1412 | .Attr("data_transfer_protocol: string = ''" ) |
1413 | .Attr("target_workers: string = 'AUTO'" ) |
1414 | .Attr("cross_trainer_cache_options: string = ''" ) |
1415 | .SetIsStateful() |
1416 | .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, |
1417 | "output_types" )) |
1418 | .SetShapeFn(shape_inference::ScalarShape); |
1419 | |
// DataServiceDatasetV2: identical to DataServiceDataset except that it adds
// the `consumer_index` and `num_consumers` int64 inputs (placed after
// `job_name`) to support round-robin reads. The attrs are unchanged from
// version 1. The op is stateful and its output is a scalar variant handle.
REGISTER_OP("DataServiceDatasetV2")
    .Input("dataset_id: int64")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    // New in V2: round-robin consumer selection.
    .Input("consumer_index: int64")
    .Input("num_consumers: int64")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    // Full type of `handle`: a TFT_DATASET container over `output_types`.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1443 | |
// DataServiceDatasetV3: identical to DataServiceDatasetV2 except that it adds
// the `uncompress` (bool, default false) and `uncompress_fn` (func, required)
// attributes to support uncompression. They sit just before
// `cross_trainer_cache_options`. The op is stateful and its output is a scalar
// variant handle.
REGISTER_OP("DataServiceDatasetV3")
    .Input("dataset_id: int64")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("consumer_index: int64")
    .Input("num_consumers: int64")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    // New in V3. `uncompress_fn` has no default, so callers must always
    // supply a function, even when `uncompress` is false.
    .Attr("uncompress: bool = false")
    .Attr("uncompress_fn: func")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    // Full type of `handle`: a TFT_DATASET container over `output_types`.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1468 | |
// DataServiceDatasetV4: identical to DataServiceDatasetV3 except that the
// `dataset_id` input changes from int64 to string. It pairs with
// RegisterDatasetV2, whose `dataset_id` output is also a string. The op is
// stateful and its output is a scalar variant handle.
REGISTER_OP("DataServiceDatasetV4")
    // Changed in V4: string instead of int64.
    .Input("dataset_id: string")
    .Input("processing_mode: string")
    .Input("address: string")
    .Input("protocol: string")
    .Input("job_name: string")
    .Input("consumer_index: int64")
    .Input("num_consumers: int64")
    .Input("max_outstanding_requests: int64")
    .Input("iteration_counter: resource")
    .Output("handle: variant")
    .Attr("task_refresh_interval_hint_ms: int = -1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("data_transfer_protocol: string = ''")
    .Attr("target_workers: string = 'AUTO'")
    .Attr("uncompress: bool = false")
    .Attr("uncompress_fn: func")
    .Attr("cross_trainer_cache_options: string = ''")
    .SetIsStateful()
    // Full type of `handle`: a TFT_DATASET container over `output_types`.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1493 | |
// Registers RegisterDataset: takes a dataset variant plus service
// `address`/`protocol` strings and produces a scalar int64 `dataset_id`.
// `external_state_policy` has no default and must always be provided. Its
// integer encoding is interpreted by the kernel (not visible here).
// This registration does not call SetIsStateful().
REGISTER_OP("RegisterDataset")
    .Input("dataset: variant")
    .Input("address: string")
    .Input("protocol: string")
    .Output("dataset_id: int64")
    .Attr("external_state_policy: int")
    .Attr("element_spec: string = ''")
    .Attr("metadata: string = ''")
    // `dataset_id` is a single scalar.
    .SetShapeFn(shape_inference::ScalarShape);
1503 | |
// RegisterDatasetV2: identical to RegisterDataset except for two changes. The
// `dataset_id` output changes from int64 to string, and a new
// `requested_dataset_id` attr (default '') is added between `element_spec` and
// `metadata`. The output is still a scalar.
REGISTER_OP("RegisterDatasetV2")
    .Input("dataset: variant")
    .Input("address: string")
    .Input("protocol: string")
    // Changed in V2: string instead of int64.
    .Output("dataset_id: string")
    .Attr("external_state_policy: int")
    .Attr("element_spec: string = ''")
    // New in V2.
    .Attr("requested_dataset_id: string = ''")
    .Attr("metadata: string = ''")
    .SetShapeFn(shape_inference::ScalarShape);
1515 | |
1516 | REGISTER_OP("InitializeTableFromDataset" ) |
1517 | .Input("table_handle: resource" ) |
1518 | .Input("dataset: variant" ) |
1519 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
1520 | shape_inference::ShapeHandle handle; |
1521 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); |
1522 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); |
1523 | return OkStatus(); |
1524 | }); |
1525 | |
// Registers ListDataset, whose `tensors` input carries the components of every
// dataset element flattened into a single list:
// - `output_types` are the dtypes of the tensors in a single dataset element.
// - `output_shapes` are the shapes of the tensors in a single dataset element.
// - `output_types` and `output_shapes` have the same length: the number of
//   tensors in a single dataset element, a.k.a. the number of components.
// - `Tinput_types` are the dtypes of the tensors for all dataset elements.
//   It is `output_types` repeated once per element, so for N elements its
//   length is N * len(output_types).
REGISTER_OP("ListDataset")
    .Input("tensors: Tinput_types")
    .Output("handle: variant")
    .Attr("Tinput_types: list(type) >= 1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    // NOTE(review): presumably keeps graph optimizers from rewriting or
    // folding this op; confirm against OpDefBuilder::SetDoNotOptimize.
    .SetDoNotOptimize()
    // Full type of `handle`: a TFT_DATASET container over `output_types`.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    // The dataset handle itself is a scalar variant.
    .SetShapeFn(shape_inference::ScalarShape);
1544 | |
1545 | } // namespace tensorflow |
1546 | |