1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/op_def_builder.h"
19#include "tensorflow/core/framework/shape_inference.h"
20
21namespace tensorflow {
22
23using shape_inference::DimensionHandle;
24using shape_inference::InferenceContext;
25using shape_inference::ShapeHandle;
26
27namespace {
28
29Status DequeueManyV2Shape(InferenceContext* c, ShapeHandle n_shape) {
30 auto* t = c->input_handle_shapes_and_types(0);
31 if (t != nullptr && t->size() == c->num_outputs()) {
32 for (int i = 0; i < c->num_outputs(); ++i) {
33 ShapeHandle combined_shape;
34 TF_RETURN_IF_ERROR(
35 c->Concatenate(n_shape, (*t)[i].shape, &combined_shape));
36 c->set_output(i, combined_shape);
37 }
38 return OkStatus();
39 } else {
40 return shape_inference::UnknownShape(c);
41 }
42}
43
44} // namespace
45
46// --------------------------------------------------------------------------
47
48REGISTER_OP("DynamicPartition")
49 .Input("data: T")
50 .Input("partitions: int32")
51 .Output("outputs: num_partitions * T")
52 .Attr("num_partitions: int")
53 .Attr("T: type")
54 .SetShapeFn([](InferenceContext* c) {
55 int64_t num_partitions;
56 TF_RETURN_IF_ERROR(c->GetAttr("num_partitions", &num_partitions));
57
58 ShapeHandle data_shape = c->input(0);
59 ShapeHandle partitions_shape = c->input(1);
60
61 if (!c->RankKnown(partitions_shape)) {
62 return shape_inference::UnknownShape(c);
63 }
64
65 const int64_t rank = c->Rank(partitions_shape);
66
67 // data shape must start with partitions_shape
68 ShapeHandle unused;
69 TF_RETURN_IF_ERROR(
70 c->MergePrefix(data_shape, partitions_shape, &unused, &unused));
71
72 // The partition shape is dynamic in the 0th dimension, and matches
73 // data_shape in the remaining dimensions.
74 ShapeHandle unknown_dim0 = c->MakeShape({c->UnknownDim()});
75
76 ShapeHandle data_suffix_shape;
77 TF_RETURN_IF_ERROR(c->Subshape(data_shape, rank, &data_suffix_shape));
78 ShapeHandle result_shape;
79 TF_RETURN_IF_ERROR(
80 c->Concatenate(unknown_dim0, data_suffix_shape, &result_shape));
81
82 for (int i = 0; i < c->num_outputs(); ++i) {
83 c->set_output(i, result_shape);
84 }
85
86 return OkStatus();
87 });
88
89namespace {
90
91Status DynamicStitchShapeFunction(InferenceContext* c) {
92 int32_t num_partitions;
93 TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
94
95 bool all_indices_constant = true;
96 int32_t max_index = -1;
97 ShapeHandle extra_shape = c->UnknownShape();
98 for (int i = 0; i < num_partitions; ++i) {
99 const Tensor* indices_t = c->input_tensor(i);
100 if (indices_t == nullptr) {
101 all_indices_constant = false;
102 }
103
104 ShapeHandle indices_shape = c->input(i);
105 ShapeHandle data_shape = c->input(i + num_partitions);
106 if (!c->RankKnown(indices_shape)) {
107 continue;
108 }
109 const int64_t indices_rank = c->Rank(indices_shape);
110
111 // Assert that data_shape starts with indices_shape.
112 ShapeHandle unused;
113 TF_RETURN_IF_ERROR(
114 c->MergePrefix(data_shape, indices_shape, &unused, &unused));
115
116 // The rest belongs to output.
117 ShapeHandle rest;
118 TF_RETURN_IF_ERROR(c->Subshape(data_shape, indices_rank, &rest));
119 TF_RETURN_IF_ERROR(c->Merge(extra_shape, rest, &extra_shape));
120
121 if (indices_t != nullptr) {
122 // The length is based on the highest index from flattened indices.
123 const int32* indices = indices_t->flat<int32>().data();
124 int64_t count = indices_t->NumElements();
125 for (int64_t i = 0; i < count; ++i) {
126 if (indices[i] > max_index) {
127 max_index = indices[i];
128 }
129 }
130 }
131 }
132
133 ShapeHandle output_shape = c->Vector(
134 all_indices_constant ? c->MakeDim(max_index + 1) : c->UnknownDim());
135 TF_RETURN_IF_ERROR(c->Concatenate(output_shape, extra_shape, &output_shape));
136 c->set_output(0, output_shape);
137 return OkStatus();
138}
139
140} // namespace
141
142REGISTER_OP("DynamicStitch")
143 .Input("indices: N * int32")
144 .Input("data: N * T")
145 .Output("merged: T")
146 .Attr("N : int >= 1")
147 .Attr("T : type")
148 .SetShapeFn(DynamicStitchShapeFunction);
149
150REGISTER_OP("ParallelDynamicStitch")
151 .Input("indices: N * int32")
152 .Input("data: N * T")
153 .Output("merged: T")
154 .Attr("N : int >= 1")
155 .Attr("T : type")
156 .SetShapeFn(DynamicStitchShapeFunction);
157
158// --------------------------------------------------------------------------
159
160namespace {
161Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
162 ShapeHandle handle;
163 DimensionHandle unused_handle;
164 for (int i = 0; i < c->num_inputs(); ++i) {
165 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
166 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
167 }
168 for (int i = 0; i < c->num_outputs(); ++i) {
169 c->set_output(i, c->Scalar());
170 }
171 return OkStatus();
172}
173
174Status TwoElementOutput(InferenceContext* c) {
175 c->set_output(0, c->Vector(2));
176 return OkStatus();
177}
178} // namespace
179
180REGISTER_OP("RandomShuffleQueue")
181 .Output("handle: Ref(string)")
182 .Attr("component_types: list(type) >= 1")
183 .Attr("shapes: list(shape) >= 0 = []")
184 .Attr("capacity: int = -1")
185 .Attr("min_after_dequeue: int = 0")
186 .Attr("seed: int = 0")
187 .Attr("seed2: int = 0")
188 .Attr("container: string = ''")
189 .Attr("shared_name: string = ''")
190 .SetIsStateful()
191 .SetShapeFn(TwoElementOutput);
192
193REGISTER_OP("RandomShuffleQueueV2")
194 .Output("handle: resource")
195 .Attr("component_types: list(type) >= 1")
196 .Attr("shapes: list(shape) >= 0 = []")
197 .Attr("capacity: int = -1")
198 .Attr("min_after_dequeue: int = 0")
199 .Attr("seed: int = 0")
200 .Attr("seed2: int = 0")
201 .Attr("container: string = ''")
202 .Attr("shared_name: string = ''")
203 .SetIsStateful()
204 .SetShapeFn(shape_inference::ScalarShape);
205
206REGISTER_OP("FIFOQueue")
207 .Output("handle: Ref(string)")
208 .Attr("component_types: list(type) >= 1")
209 .Attr("shapes: list(shape) >= 0 = []")
210 .Attr("capacity: int = -1")
211 .Attr("container: string = ''")
212 .Attr("shared_name: string = ''")
213 .SetIsStateful()
214 .SetShapeFn(TwoElementOutput);
215
216REGISTER_OP("FIFOQueueV2")
217 .Output("handle: resource")
218 .Attr("component_types: list(type) >= 1")
219 .Attr("shapes: list(shape) >= 0 = []")
220 .Attr("capacity: int = -1")
221 .Attr("container: string = ''")
222 .Attr("shared_name: string = ''")
223 .SetIsStateful()
224 .SetShapeFn(shape_inference::ScalarShape);
225
226REGISTER_OP("PaddingFIFOQueue")
227 .Output("handle: Ref(string)")
228 .Attr("component_types: list(type) >= 1")
229 .Attr("shapes: list(shape) >= 0 = []")
230 .Attr("capacity: int = -1")
231 .Attr("container: string = ''")
232 .Attr("shared_name: string = ''")
233 .SetIsStateful()
234 .SetShapeFn(TwoElementOutput);
235
236REGISTER_OP("PaddingFIFOQueueV2")
237 .Output("handle: resource")
238 .Attr("component_types: list(type) >= 1")
239 .Attr("shapes: list(shape) >= 0 = []")
240 .Attr("capacity: int = -1")
241 .Attr("container: string = ''")
242 .Attr("shared_name: string = ''")
243 .SetIsStateful()
244 .SetShapeFn(shape_inference::ScalarShape);
245
246REGISTER_OP("PriorityQueue")
247 .Output("handle: Ref(string)")
248 .Attr("component_types: list(type) >= 0 = []")
249 .Attr("shapes: list(shape) >= 0")
250 .Attr("capacity: int = -1")
251 .Attr("container: string = ''")
252 .Attr("shared_name: string = ''")
253 .SetIsStateful()
254 .SetShapeFn(TwoElementOutput);
255
256REGISTER_OP("PriorityQueueV2")
257 .Output("handle: resource")
258 .Attr("component_types: list(type) >= 0 = []")
259 .Attr("shapes: list(shape) >= 0")
260 .Attr("capacity: int = -1")
261 .Attr("container: string = ''")
262 .Attr("shared_name: string = ''")
263 .SetIsStateful()
264 .SetShapeFn(shape_inference::ScalarShape);
265
266REGISTER_OP("FakeQueue")
267 .Input("resource: resource")
268 .Output("handle: Ref(string)")
269 .SetIsStateful()
270 .SetShapeFn(TwoElementOutput);
271
272REGISTER_OP("QueueEnqueue")
273 .Input("handle: Ref(string)")
274 .Input("components: Tcomponents")
275 .Attr("Tcomponents: list(type) >= 1")
276 .Attr("timeout_ms: int = -1")
277 .SetShapeFn(shape_inference::UnknownShape);
278
279REGISTER_OP("QueueEnqueueV2")
280 .Input("handle: resource")
281 .Input("components: Tcomponents")
282 .Attr("Tcomponents: list(type) >= 1")
283 .Attr("timeout_ms: int = -1")
284 .SetShapeFn(shape_inference::UnknownShape);
285
286REGISTER_OP("QueueEnqueueMany")
287 .Input("handle: Ref(string)")
288 .Input("components: Tcomponents")
289 .Attr("Tcomponents: list(type) >= 1")
290 .Attr("timeout_ms: int = -1")
291 .SetShapeFn(shape_inference::UnknownShape);
292
293REGISTER_OP("QueueEnqueueManyV2")
294 .Input("handle: resource")
295 .Input("components: Tcomponents")
296 .Attr("Tcomponents: list(type) >= 1")
297 .Attr("timeout_ms: int = -1")
298 .SetShapeFn(shape_inference::UnknownShape);
299
300REGISTER_OP("QueueDequeue")
301 .Input("handle: Ref(string)")
302 .Output("components: component_types")
303 .Attr("component_types: list(type) >= 1")
304 .Attr("timeout_ms: int = -1")
305 .SetShapeFn(shape_inference::UnknownShape);
306
307REGISTER_OP("QueueDequeueV2")
308 .Input("handle: resource")
309 .Output("components: component_types")
310 .Attr("component_types: list(type) >= 1")
311 .Attr("timeout_ms: int = -1")
312 .SetShapeFn([](InferenceContext* c) {
313 auto* t = c->input_handle_shapes_and_types(0);
314 if (t != nullptr && t->size() == c->num_outputs()) {
315 for (int i = 0; i < c->num_outputs(); ++i) {
316 c->set_output(i, (*t)[i].shape);
317 }
318 return OkStatus();
319 } else {
320 return shape_inference::UnknownShape(c);
321 }
322 });
323
324REGISTER_OP("QueueDequeueMany")
325 .Input("handle: Ref(string)")
326 .Input("n: int32")
327 .Output("components: component_types")
328 .Attr("component_types: list(type) >= 1")
329 .Attr("timeout_ms: int = -1")
330 .SetShapeFn(shape_inference::UnknownShape);
331
332REGISTER_OP("QueueDequeueManyV2")
333 .Input("handle: resource")
334 .Input("n: int32")
335 .Output("components: component_types")
336 .Attr("component_types: list(type) >= 1")
337 .Attr("timeout_ms: int = -1")
338 .SetShapeFn([](InferenceContext* c) {
339 ShapeHandle n_shape;
340 if (c->input_tensor(1) == nullptr) {
341 n_shape = c->Vector(InferenceContext::kUnknownDim);
342 } else {
343 const int32_t n = c->input_tensor(1)->scalar<int32>()();
344 if (n < 0) {
345 return errors::InvalidArgument("Input 'n' must be >= 0, but is ", n);
346 }
347 n_shape = c->Vector(n);
348 }
349 return DequeueManyV2Shape(c, n_shape);
350 });
351
352REGISTER_OP("QueueDequeueUpTo")
353 .Input("handle: Ref(string)")
354 .Input("n: int32")
355 .Output("components: component_types")
356 .Attr("component_types: list(type) >= 1")
357 .Attr("timeout_ms: int = -1")
358 .SetShapeFn(shape_inference::UnknownShape);
359
360REGISTER_OP("QueueDequeueUpToV2")
361 .Input("handle: resource")
362 .Input("n: int32")
363 .Output("components: component_types")
364 .Attr("component_types: list(type) >= 1")
365 .Attr("timeout_ms: int = -1")
366 .SetShapeFn([](InferenceContext* c) {
367 return DequeueManyV2Shape(c, c->Vector(InferenceContext::kUnknownDim));
368 });
369
370REGISTER_OP("QueueClose")
371 .Input("handle: Ref(string)")
372 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
373 .Attr("cancel_pending_enqueues: bool = false");
374
375REGISTER_OP("QueueCloseV2")
376 .Input("handle: resource")
377 .SetShapeFn(shape_inference::NoOutputs)
378 .Attr("cancel_pending_enqueues: bool = false");
379
380REGISTER_OP("QueueIsClosed")
381 .Input("handle: Ref(string)")
382 .Output("is_closed: bool")
383 .SetShapeFn(shape_inference::ScalarShape);
384
385REGISTER_OP("QueueIsClosedV2")
386 .Input("handle: resource")
387 .Output("is_closed: bool")
388 .SetShapeFn(shape_inference::ScalarShape);
389
390REGISTER_OP("QueueSize")
391 .Input("handle: Ref(string)")
392 .Output("size: int32")
393 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
394
395REGISTER_OP("QueueSizeV2")
396 .Input("handle: resource")
397 .Output("size: int32")
398 .SetShapeFn(shape_inference::UnchangedShape);
399
400// --------------------------------------------------------------------------
401
402REGISTER_OP("AccumulatorNumAccumulated")
403 .Input("handle: Ref(string)")
404 .Output("num_accumulated: int32")
405 .SetShapeFn(shape_inference::ScalarShape);
406
407REGISTER_OP("AccumulatorSetGlobalStep")
408 .Input("handle: Ref(string)")
409 .Input("new_global_step: int64")
410 .SetShapeFn([](InferenceContext* c) {
411 ShapeHandle unused;
412 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
413 return OkStatus();
414 });
415
416REGISTER_OP("ConditionalAccumulator")
417 .Output("handle: Ref(string)")
418 .Attr("dtype: numbertype")
419 .Attr("shape: shape")
420 .Attr("container: string = ''")
421 .Attr("shared_name: string = ''")
422 .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
423 .SetIsStateful()
424 .SetShapeFn([](InferenceContext* c) {
425 c->set_output(0, c->Vector(2));
426 return OkStatus();
427 });
428
429REGISTER_OP("AccumulatorApplyGradient")
430 .Input("handle: Ref(string)")
431 .Input("local_step: int64")
432 .Input("gradient: dtype")
433 .Attr("dtype: numbertype")
434 .SetShapeFn([](InferenceContext* c) {
435 ShapeHandle unused;
436 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
437 return OkStatus();
438 });
439
440REGISTER_OP("AccumulatorTakeGradient")
441 .Input("handle: Ref(string)")
442 .Input("num_required: int32")
443 .Output("average: dtype")
444 .SetShapeFn([](InferenceContext* c) {
445 ShapeHandle unused;
446 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
447 // Shape of output is the shape of the accumulator referenced
448 // by 'handle', but which is not available here, so we lose
449 // shape information.
450 return shape_inference::UnknownShape(c);
451 })
452 .Attr("dtype: numbertype");
453
454// -----------------V2 accumulators that use resource -------------------------
455
456REGISTER_OP("ResourceAccumulatorNumAccumulated")
457 .Input("handle: resource")
458 .Output("num_accumulated: int32")
459 .SetShapeFn(shape_inference::ScalarShape);
460
461REGISTER_OP("ResourceAccumulatorSetGlobalStep")
462 .Input("handle: resource")
463 .Input("new_global_step: int64")
464 .SetShapeFn([](InferenceContext* c) {
465 ShapeHandle unused;
466 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
467 return OkStatus();
468 });
469
470REGISTER_OP("ResourceConditionalAccumulator")
471 .Output("handle: resource")
472 .Attr("dtype: numbertype")
473 .Attr("shape: shape")
474 .Attr("container: string = ''")
475 .Attr("shared_name: string = ''")
476 .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
477 .SetIsStateful()
478 .SetShapeFn([](InferenceContext* c) {
479 c->set_output(0, c->Vector(2));
480 return OkStatus();
481 });
482
483REGISTER_OP("ResourceAccumulatorApplyGradient")
484 .Input("handle: resource")
485 .Input("local_step: int64")
486 .Input("gradient: dtype")
487 .Attr("dtype: numbertype")
488 .SetShapeFn([](InferenceContext* c) {
489 ShapeHandle unused;
490 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
491 return OkStatus();
492 });
493
494REGISTER_OP("ResourceAccumulatorTakeGradient")
495 .Input("handle: resource")
496 .Input("num_required: int32")
497 .Output("average: dtype")
498 .SetShapeFn([](InferenceContext* c) {
499 ShapeHandle unused;
500 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
501 // Shape of output is the shape of the accumulator referenced
502 // by 'handle', but which is not available here, so we lose
503 // shape information.
504 return shape_inference::UnknownShape(c);
505 })
506 .Attr("dtype: numbertype");
507
508// TODO(nponomareva): change these all to use resources.
509REGISTER_OP("SparseConditionalAccumulator")
510 .Output("handle: Ref(string)")
511 .Attr("dtype: numbertype")
512 .Attr("shape: shape")
513 .Attr("container: string = ''")
514 .Attr("shared_name: string = ''")
515 .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
516 .SetIsStateful()
517 .SetShapeFn([](InferenceContext* c) {
518 c->set_output(0, c->Vector(2));
519 return OkStatus();
520 });
521
522REGISTER_OP("SparseAccumulatorApplyGradient")
523 .Input("handle: Ref(string)")
524 .Input("local_step: int64")
525 .Input("gradient_indices: int64")
526 .Input("gradient_values: dtype")
527 .Input("gradient_shape: int64")
528 .Attr("dtype: numbertype")
529 .Attr("has_known_shape: bool")
530 .SetShapeFn([](InferenceContext* c) {
531 ShapeHandle unused;
532 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
533 return OkStatus();
534 });
535
536REGISTER_OP("SparseAccumulatorTakeGradient")
537 .Input("handle: Ref(string)")
538 .Input("num_required: int32")
539 .Output("indices: int64")
540 .Output("values: dtype")
541 .Output("shape: int64")
542 .Attr("dtype: numbertype")
543 .SetShapeFn([](InferenceContext* c) {
544 ShapeHandle unused;
545 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
546 // Shape of output is the shape of the accumulator referenced
547 // by 'handle', but which is not available here, so we lose
548 // shape information.
549 return shape_inference::UnknownShape(c);
550 });
551
552// --------------------------------------------------------------------------
553
554REGISTER_OP("StackV2")
555 .Input("max_size: int32")
556 .Output("handle: resource")
557 .Attr("elem_type: type")
558 .Attr("stack_name: string = ''")
559 .SetIsStateful()
560 .SetShapeFn(TwoElementOutput);
561
562REGISTER_OP("StackPushV2")
563 .Input("handle: resource")
564 .Input("elem: T")
565 .Output("output: T")
566 .Attr("T: type")
567 .Attr("swap_memory: bool = false")
568 .SetShapeFn([](shape_inference::InferenceContext* c) {
569 c->set_output(0, c->input(1));
570 return OkStatus();
571 });
572
573REGISTER_OP("StackPopV2")
574 .Input("handle: resource")
575 .Output("elem: elem_type")
576 .Attr("elem_type: type")
577 .SetShapeFn(shape_inference::UnknownShape);
578
579REGISTER_OP("StackCloseV2")
580 .Input("handle: resource")
581 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
582
583// Deprecated ref-typed variants of stack.
584
585REGISTER_OP("Stack")
586 .Output("handle: Ref(string)")
587 .Attr("elem_type: type")
588 .Attr("stack_name: string = ''")
589 .SetIsStateful()
590 .SetShapeFn(TwoElementOutput);
591
592REGISTER_OP("StackPush")
593 .Input("handle: Ref(string)")
594 .Input("elem: T")
595 .Output("output: T")
596 .Attr("T: type")
597 .Attr("swap_memory: bool = false")
598 .SetShapeFn([](shape_inference::InferenceContext* c) {
599 c->set_output(0, c->input(1));
600 return OkStatus();
601 });
602
603REGISTER_OP("StackPop")
604 .Input("handle: Ref(string)")
605 .Output("elem: elem_type")
606 .Attr("elem_type: type")
607 .SetShapeFn(shape_inference::UnknownShape);
608
609REGISTER_OP("StackClose")
610 .Input("handle: Ref(string)")
611 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
612
613// --------------------------------------------------------------------------
614
615REGISTER_OP("TensorArrayV3")
616 .Input("size: int32")
617 .Attr("dtype: type")
618 .Attr("element_shape: shape = { unknown_rank: true }")
619 .Attr("dynamic_size: bool = false")
620 .Attr("clear_after_read: bool = true")
621 .Attr("identical_element_shapes: bool = false")
622 .Attr("tensor_array_name: string = ''")
623 .Output("handle: resource")
624 .Output("flow: float")
625 .SetIsStateful()
626 .SetShapeFn([](InferenceContext* c) {
627 ShapeHandle unused;
628 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
629 c->set_output(0, c->Vector(2));
630 c->set_output(1, c->Scalar());
631 bool identical_shapes;
632 TF_RETURN_IF_ERROR(
633 c->GetAttr("identical_element_shapes", &identical_shapes));
634 DataType t;
635 TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
636 PartialTensorShape p;
637 TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
638 ShapeHandle s;
639 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
640 if (c->FullyDefined(s) || identical_shapes) {
641 c->set_output_handle_shapes_and_types(
642 0, std::vector<shape_inference::ShapeAndType>{{s, t}});
643 }
644 return OkStatus();
645 });
646
647REGISTER_OP("TensorArrayGradV3")
648 .Input("handle: resource")
649 .Input("flow_in: float")
650 .Output("grad_handle: resource")
651 .Output("flow_out: float")
652 .Attr("source: string")
653 .SetIsStateful()
654 .SetShapeFn([](InferenceContext* c) {
655 ShapeHandle handle;
656 DimensionHandle unused_dim;
657 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
658 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
659 c->set_output(0, c->Vector(2));
660 c->set_output(1, c->Scalar());
661 if (c->input_handle_shapes_and_types(0)) {
662 c->set_output_handle_shapes_and_types(
663 0, *c->input_handle_shapes_and_types(0));
664 }
665 return OkStatus();
666 });
667
668REGISTER_OP("TensorArrayGradWithShape")
669 .Input("handle: resource")
670 .Input("flow_in: float")
671 .Input("shape_to_prepend: int32")
672 .Output("grad_handle: resource")
673 .Output("flow_out: float")
674 .Attr("source: string")
675 .SetIsStateful()
676 .SetShapeFn([](InferenceContext* c) {
677 ShapeHandle handle;
678 DimensionHandle unused_dim;
679 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
680 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
681 c->set_output(0, c->Vector(2));
682 c->set_output(1, c->Scalar());
683 auto* shape_and_type = c->input_handle_shapes_and_types(0);
684 if (shape_and_type) {
685 auto input_shape = (*shape_and_type)[0].shape;
686 auto dtype = (*shape_and_type)[0].dtype;
687 // Note that shape_to_preped is a rank 1 Tensor representing a shape.
688 // The size of dimension 0 is the number of dimensions we need to add to
689 // output shape.
690 int64_t prepend_rank = c->Value(c->Dim(c->input(2), 0));
691 if (c->RankKnown(input_shape) &&
692 prepend_rank != InferenceContext::kUnknownDim) {
693 int32_t input_rank = c->Rank(input_shape);
694 std::vector<DimensionHandle> dims;
695 dims.reserve(prepend_rank + input_rank);
696 for (int i = 0; i < prepend_rank; ++i) {
697 dims.push_back(c->UnknownDim());
698 }
699 for (int i = 0; i < input_rank; ++i) {
700 dims.push_back(c->Dim(input_shape, i));
701 }
702 c->set_output_handle_shapes_and_types(0,
703 {{c->MakeShape(dims), dtype}});
704 } else {
705 c->set_output_handle_shapes_and_types(0,
706 {{c->UnknownShape(), dtype}});
707 }
708 }
709 return OkStatus();
710 });
711
712REGISTER_OP("TensorArrayWriteV3")
713 .Input("handle: resource")
714 .Input("index: int32")
715 .Input("value: T")
716 .Input("flow_in: float")
717 .Output("flow_out: float")
718 .Attr("T: type")
719 .SetShapeFn([](InferenceContext* c) {
720 ShapeHandle handle;
721 DimensionHandle unused_dim;
722 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
723 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
724
725 ShapeHandle unused;
726 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
727 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
728
729 auto* handle_data = c->input_handle_shapes_and_types(0);
730 if (handle_data != nullptr && !handle_data->empty()) {
731 shape_inference::ShapeAndType shape_and_type = (*handle_data)[0];
732 ShapeHandle value_shape = c->input(2);
733 TF_RETURN_IF_ERROR(
734 c->Merge(shape_and_type.shape, value_shape, &unused));
735 }
736
737 return shape_inference::ScalarShape(c);
738 });
739
740REGISTER_OP("TensorArrayReadV3")
741 .Input("handle: resource")
742 .Input("index: int32")
743 .Input("flow_in: float")
744 .Output("value: dtype")
745 .Attr("dtype: type")
746 .SetShapeFn([](InferenceContext* c) {
747 ShapeHandle handle;
748 DimensionHandle unused_dim;
749 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
750 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
751 ShapeHandle unused;
752 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
753 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
754 auto shapes = c->input_handle_shapes_and_types(0);
755 if (shapes != nullptr && !shapes->empty()) {
756 ShapeHandle tensor_shape = shapes->at(0).shape;
757 c->set_output(0, tensor_shape);
758 return OkStatus();
759 } else {
760 return shape_inference::UnknownShape(c);
761 }
762 });
763
764REGISTER_OP("TensorArrayGatherV3")
765 .Input("handle: resource")
766 .Input("indices: int32")
767 .Input("flow_in: float")
768 .Output("value: dtype")
769 .Attr("dtype: type")
770 .Attr("element_shape: shape = { unknown_rank: true }")
771 .SetShapeFn([](InferenceContext* c) {
772 ShapeHandle indices;
773 ShapeHandle unused;
774 DimensionHandle unused_dim;
775 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
776 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
777 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
778 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
779 auto shapes = c->input_handle_shapes_and_types(0);
780 if (shapes != nullptr && !shapes->empty()) {
781 ShapeHandle tensor_shape = shapes->at(0).shape;
782 ShapeHandle output_shape;
783 TF_RETURN_IF_ERROR(
784 c->Concatenate(indices, tensor_shape, &output_shape));
785 c->set_output(0, output_shape);
786 return OkStatus();
787 } else {
788 PartialTensorShape p;
789 TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
790 ShapeHandle s;
791 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
792 ShapeHandle output_shape;
793 TF_RETURN_IF_ERROR(c->Concatenate(indices, s, &output_shape));
794 c->set_output(0, output_shape);
795 return OkStatus();
796 }
797 });
798
799REGISTER_OP("TensorArrayScatterV3")
800 .Input("handle: resource")
801 .Input("indices: int32")
802 .Input("value: T")
803 .Input("flow_in: float")
804 .Output("flow_out: float")
805 .Attr("T: type")
806 .SetShapeFn([](InferenceContext* c) {
807 ShapeHandle indices;
808 ShapeHandle unused;
809 DimensionHandle unused_dim;
810 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
811 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
812 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
813 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
814 ShapeHandle value_shape;
815 // Assert that the length of the indices tensor is equal to the first
816 // dimension of the value tensor.
817 TF_RETURN_IF_ERROR(
818 c->MergePrefix(c->input(2), indices, &value_shape, &indices));
819 auto shapes = c->input_handle_shapes_and_types(0);
820 if (shapes != nullptr && !shapes->empty()) {
821 ShapeHandle tensor_shape = shapes->at(0).shape;
822 ShapeHandle fed_shape;
823 TF_RETURN_IF_ERROR(c->Subshape(value_shape, 1, &fed_shape));
824 TF_RETURN_IF_ERROR(c->Merge(tensor_shape, fed_shape, &fed_shape));
825 }
826 return shape_inference::ScalarShape(c);
827 });
828
829REGISTER_OP("TensorArrayConcatV3")
830 .Input("handle: resource")
831 .Input("flow_in: float")
832 .Output("value: dtype")
833 .Output("lengths: int64")
834 .Attr("dtype: type")
835 .Attr("element_shape_except0: shape = { unknown_rank: true }")
836 .SetShapeFn([](InferenceContext* c) {
837 ShapeHandle handle;
838 DimensionHandle unused_dim;
839 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
840 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
841 ShapeHandle unused;
842 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
843 c->set_output(0, c->UnknownShape());
844 c->set_output(1, c->Vector(c->UnknownDim()));
845 return OkStatus();
846 });
847
848REGISTER_OP("TensorArraySplitV3")
849 .Input("handle: resource")
850 .Input("value: T")
851 .Input("lengths: int64")
852 .Input("flow_in: float")
853 .Output("flow_out: float")
854 .Attr("T: type")
855 .SetShapeFn([](InferenceContext* c) {
856 ShapeHandle handle;
857 DimensionHandle unused_dim;
858 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
859 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
860 ShapeHandle unused;
861 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
862 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
863 return shape_inference::ScalarShape(c);
864 });
865
866REGISTER_OP("TensorArraySizeV3")
867 .Input("handle: resource")
868 .Input("flow_in: float")
869 .Output("size: int32")
870 .SetShapeFn([](InferenceContext* c) {
871 ShapeHandle handle;
872 DimensionHandle unused_dim;
873 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
874 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
875 return shape_inference::ScalarShape(c);
876 });
877
878REGISTER_OP("TensorArrayCloseV3")
879 .Input("handle: resource")
880 .SetShapeFn([](InferenceContext* c) {
881 ShapeHandle handle;
882 DimensionHandle unused_dim;
883 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
884 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
885 return OkStatus();
886 });
887
888// --------------------------------------------------------------------------
889
890// Deprecated TensorArray methods
891
892REGISTER_OP("TensorArray")
893 .Input("size: int32")
894 .Attr("dtype: type")
895 .Attr("dynamic_size: bool = false")
896 .Attr("clear_after_read: bool = true")
897 .Attr("tensor_array_name: string = ''")
898 .Attr("element_shape: shape = { unknown_rank: true }")
899 .Output("handle: Ref(string)")
900 .SetIsStateful()
901 .SetShapeFn(shape_inference::UnknownShape)
902 .Deprecated(16, "Use TensorArrayV3");
903REGISTER_OP("TensorArrayV2")
904 .Input("size: int32")
905 .Attr("dtype: type")
906 .Attr("element_shape: shape = { unknown_rank: true }")
907 .Attr("dynamic_size: bool = false")
908 .Attr("clear_after_read: bool = true")
909 .Attr("tensor_array_name: string = ''")
910 .Output("handle: string")
911 .SetIsStateful()
912 .SetShapeFn([](InferenceContext* c) {
913 ShapeHandle unused;
914 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
915 c->set_output(0, c->Vector(2));
916 return OkStatus();
917 })
918 .Deprecated(26, "Use TensorArrayV3");
919REGISTER_OP("TensorArrayGrad")
920 .Input("handle: string")
921 .Input("flow_in: float")
922 .Output("grad_handle: Ref(string)")
923 .Attr("source: string")
924 .SetIsStateful()
925 .SetShapeFn(shape_inference::UnknownShape)
926 .Deprecated(16, "Use TensorArrayGradV3");
927REGISTER_OP("TensorArrayGradV2")
928 .Input("handle: string")
929 .Input("flow_in: float")
930 .Output("grad_handle: string")
931 .Attr("source: string")
932 .SetIsStateful()
933 .SetShapeFn([](InferenceContext* c) {
934 ShapeHandle handle;
935 DimensionHandle unused_dim;
936 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
937 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
938 c->set_output(0, c->Vector(2));
939 return OkStatus();
940 })
941 .Deprecated(26, "Use TensorArrayGradV3");
942REGISTER_OP("TensorArrayWrite")
943 .Input("handle: Ref(string)")
944 .Input("index: int32")
945 .Input("value: T")
946 .Input("flow_in: float")
947 .Output("flow_out: float")
948 .Attr("T: type")
949 .SetShapeFn(shape_inference::UnknownShape)
950 .Deprecated(16, "Use TensorArrayWriteV3");
951REGISTER_OP("TensorArrayWriteV2")
952 .Input("handle: string")
953 .Input("index: int32")
954 .Input("value: T")
955 .Input("flow_in: float")
956 .Output("flow_out: float")
957 .Attr("T: type")
958 .SetShapeFn([](InferenceContext* c) {
959 ShapeHandle handle;
960 DimensionHandle unused_dim;
961 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
962 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
963
964 ShapeHandle unused;
965 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
966 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
967 return shape_inference::ScalarShape(c);
968 })
969 .Deprecated(26, "Use TensorArrayWriteV3");
970REGISTER_OP("TensorArrayRead")
971 .Input("handle: Ref(string)")
972 .Input("index: int32")
973 .Input("flow_in: float")
974 .Output("value: dtype")
975 .Attr("dtype: type")
976 .SetShapeFn(shape_inference::UnknownShape)
977 .Deprecated(16, "Use TensorArrayReadV3");
978REGISTER_OP("TensorArrayReadV2")
979 .Input("handle: string")
980 .Input("index: int32")
981 .Input("flow_in: float")
982 .Output("value: dtype")
983 .Attr("dtype: type")
984 .SetShapeFn([](InferenceContext* c) {
985 ShapeHandle handle;
986 DimensionHandle unused_dim;
987 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
988 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
989 ShapeHandle unused;
990 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
991 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
992 return shape_inference::UnknownShape(c);
993 })
994 .Deprecated(26, "Use TensorArrayReadV3");
995REGISTER_OP("TensorArrayPack")
996 .Input("handle: Ref(string)")
997 .Input("flow_in: float")
998 .Output("value: dtype")
999 .Attr("dtype: type")
1000 .Attr("element_shape: shape = { unknown_rank: true }")
1001 .SetShapeFn(shape_inference::UnknownShape)
1002 .Deprecated(16, "Use TensorArrayGatherV3 with RangeOp");
1003REGISTER_OP("TensorArrayUnpack")
1004 .Input("handle: Ref(string)")
1005 .Input("value: T")
1006 .Input("flow_in: float")
1007 .Output("flow_out: float")
1008 .Attr("T: type")
1009 .SetShapeFn(shape_inference::UnknownShape)
1010 .Deprecated(20, "Use TensorArrayScatterV3 with RangeOp");
1011REGISTER_OP("TensorArrayGather")
1012 .Input("handle: Ref(string)")
1013 .Input("indices: int32")
1014 .Input("flow_in: float")
1015 .Output("value: dtype")
1016 .Attr("dtype: type")
1017 .Attr("element_shape: shape = { unknown_rank: true }")
1018 .SetShapeFn(shape_inference::UnknownShape)
1019 .Deprecated(16, "Use TensorArrayGatherV3");
1020REGISTER_OP("TensorArrayGatherV2")
1021 .Input("handle: string")
1022 .Input("indices: int32")
1023 .Input("flow_in: float")
1024 .Output("value: dtype")
1025 .Attr("dtype: type")
1026 .Attr("element_shape: shape = { unknown_rank: true }")
1027 .SetShapeFn([](InferenceContext* c) {
1028 ShapeHandle unused;
1029 DimensionHandle unused_dim;
1030 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1031 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1032 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
1033 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1034 return shape_inference::UnknownShape(c);
1035 })
1036 .Deprecated(26, "Use TensorArrayGatherV3");
1037REGISTER_OP("TensorArrayScatter")
1038 .Input("handle: Ref(string)")
1039 .Input("indices: int32")
1040 .Input("value: T")
1041 .Input("flow_in: float")
1042 .Output("flow_out: float")
1043 .Attr("T: type")
1044 .SetShapeFn(shape_inference::UnknownShape)
1045 .Deprecated(19, "Use TensorArrayGradV3");
1046REGISTER_OP("TensorArrayScatterV2")
1047 .Input("handle: string")
1048 .Input("indices: int32")
1049 .Input("value: T")
1050 .Input("flow_in: float")
1051 .Output("flow_out: float")
1052 .Attr("T: type")
1053 .SetShapeFn([](InferenceContext* c) {
1054 ShapeHandle unused;
1055 DimensionHandle unused_dim;
1056 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
1057 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1058 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
1059 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1060 return shape_inference::ScalarShape(c);
1061 })
1062 .Deprecated(26, "Use TensorArrayScatterV3");
1063REGISTER_OP("TensorArrayConcat")
1064 .Input("handle: Ref(string)")
1065 .Input("flow_in: float")
1066 .Output("value: dtype")
1067 .Output("lengths: int64")
1068 .Attr("dtype: type")
1069 .Attr("element_shape_except0: shape = { unknown_rank: true }")
1070 .SetShapeFn(shape_inference::UnknownShape)
1071 .Deprecated(16, "Use TensorArrayGradV3");
1072REGISTER_OP("TensorArrayConcatV2")
1073 .Input("handle: string")
1074 .Input("flow_in: float")
1075 .Output("value: dtype")
1076 .Output("lengths: int64")
1077 .Attr("dtype: type")
1078 .Attr("element_shape_except0: shape = { unknown_rank: true }")
1079 .SetShapeFn([](InferenceContext* c) {
1080 ShapeHandle handle;
1081 DimensionHandle unused_dim;
1082 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1083 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1084 ShapeHandle unused;
1085 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1086 c->set_output(0, c->UnknownShape());
1087 c->set_output(1, c->Vector(c->UnknownDim()));
1088 return OkStatus();
1089 });
1090REGISTER_OP("TensorArraySplit")
1091 .Input("handle: Ref(string)")
1092 .Input("value: T")
1093 .Input("lengths: int64")
1094 .Input("flow_in: float")
1095 .Output("flow_out: float")
1096 .Attr("T: type")
1097 .SetShapeFn(shape_inference::UnknownShape)
1098 .Deprecated(16, "Use TensorArraySplitV3");
1099REGISTER_OP("TensorArraySplitV2")
1100 .Input("handle: string")
1101 .Input("value: T")
1102 .Input("lengths: int64")
1103 .Input("flow_in: float")
1104 .Output("flow_out: float")
1105 .Attr("T: type")
1106 .SetShapeFn([](InferenceContext* c) {
1107 ShapeHandle handle;
1108 DimensionHandle unused_dim;
1109 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1110 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1111 ShapeHandle unused;
1112 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
1113 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1114 return shape_inference::ScalarShape(c);
1115 })
1116 .Deprecated(26, "Use TensorArraySplitV3");
1117REGISTER_OP("TensorArraySize")
1118 .Input("handle: Ref(string)")
1119 .Input("flow_in: float")
1120 .Output("size: int32")
1121 .SetShapeFn(shape_inference::UnknownShape)
1122 .Deprecated(16, "Use TensorArraySizeV3");
1123REGISTER_OP("TensorArraySizeV2")
1124 .Input("handle: string")
1125 .Input("flow_in: float")
1126 .Output("size: int32")
1127 .SetShapeFn([](InferenceContext* c) {
1128 ShapeHandle handle;
1129 DimensionHandle unused_dim;
1130 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1131 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1132 return shape_inference::ScalarShape(c);
1133 })
1134 .Deprecated(26, "Use TensorArraySizeV3");
1135REGISTER_OP("TensorArrayClose")
1136 .Input("handle: Ref(string)")
1137 .SetShapeFn([](InferenceContext* c) { return OkStatus(); })
1138 .Deprecated(16, "Use TensorArrayCloseV3");
1139REGISTER_OP("TensorArrayCloseV2")
1140 .Input("handle: string")
1141 .SetShapeFn([](InferenceContext* c) {
1142 ShapeHandle handle;
1143 DimensionHandle unused_dim;
1144 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1145 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1146 return OkStatus();
1147 })
1148 .Deprecated(26, "Use TensorArrayCloseV3");
1149
1150// --------------------------------------------------------------------------
1151
1152REGISTER_OP("Barrier")
1153 .SetIsStateful()
1154 .Output("handle: Ref(string)")
1155 .Attr("component_types: list(type) >= 1")
1156 .Attr("shapes: list(shape) >= 0 = []")
1157 .Attr("capacity: int = -1")
1158 .Attr("container: string = ''")
1159 .Attr("shared_name: string = ''")
1160 .SetShapeFn(TwoElementOutput);
1161
1162REGISTER_OP("BarrierInsertMany")
1163 .Input("handle: Ref(string)")
1164 .Input("keys: string")
1165 .Input("values: T")
1166 .Attr("T: type")
1167 .Attr("component_index: int")
1168 .SetShapeFn([](InferenceContext* c) {
1169 ShapeHandle keys = c->input(1);
1170 ShapeHandle values = c->input(2);
1171 ShapeHandle handle;
1172 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
1173 DimensionHandle unused_dim;
1174 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
1175 TF_RETURN_IF_ERROR(c->WithRank(keys, 1, &keys));
1176 TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
1177 TF_RETURN_IF_ERROR(c->Merge(keys, c->Vector(c->Dim(values, 0)), &handle));
1178 return OkStatus();
1179 });
1180
1181REGISTER_OP("BarrierTakeMany")
1182 .Input("handle: Ref(string)")
1183 .Input("num_elements: int32")
1184 .Output("indices: int64")
1185 .Output("keys: string")
1186 .Output("values: component_types")
1187 .Attr("component_types: list(type) >= 1")
1188 .Attr("allow_small_batch: bool = false")
1189 .Attr("wait_for_incomplete: bool = false")
1190 .Attr("timeout_ms: int = -1")
1191 .SetShapeFn(shape_inference::UnknownShape);
1192
1193REGISTER_OP("BarrierClose")
1194 .Input("handle: Ref(string)")
1195 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
1196 .Attr("cancel_pending_enqueues: bool = false");
1197
1198REGISTER_OP("BarrierReadySize")
1199 .Input("handle: Ref(string)")
1200 .Output("size: int32")
1201 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
1202
1203REGISTER_OP("BarrierIncompleteSize")
1204 .Input("handle: Ref(string)")
1205 .Output("size: int32")
1206 .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
1207
1208// --------------------------------------------------------------------------
1209
1210REGISTER_OP("GetSessionHandle")
1211 .Input("value: T")
1212 .Output("handle: string")
1213 .Attr("T: type")
1214 .SetIsStateful()
1215 .SetShapeFn(shape_inference::ScalarShape);
1216
1217REGISTER_OP("GetSessionHandleV2")
1218 .Input("value: T")
1219 .Output("handle: resource")
1220 .Attr("T: type")
1221 .SetIsStateful()
1222 .SetShapeFn(shape_inference::ScalarShape);
1223
1224REGISTER_OP("GetSessionTensor")
1225 .Input("handle: string")
1226 .Output("value: dtype")
1227 .Attr("dtype: type")
1228 .SetIsStateful()
1229 .SetShapeFn([](InferenceContext* c) {
1230 ShapeHandle unused;
1231 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1232 return shape_inference::UnknownShape(c);
1233 });
1234
1235REGISTER_OP("DeleteSessionTensor")
1236 .Input("handle: string")
1237 .SetIsStateful()
1238 .SetShapeFn([](InferenceContext* c) {
1239 ShapeHandle unused;
1240 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1241 return OkStatus();
1242 });
1243
1244REGISTER_OP("Stage")
1245 .Input("values: dtypes")
1246 .Attr("capacity: int >= 0 = 0")
1247 .Attr("memory_limit: int >= 0 = 0")
1248 .Attr("dtypes: list(type)")
1249 .Attr("container: string = ''")
1250 .Attr("shared_name: string = ''")
1251 .SetShapeFn(shape_inference::UnknownShape)
1252 .SetIsStateful();
1253
1254REGISTER_OP("Unstage")
1255 .Output("values: dtypes")
1256 .Attr("capacity: int >= 0 = 0")
1257 .Attr("memory_limit: int >= 0 = 0")
1258 .Attr("dtypes: list(type)")
1259 .Attr("container: string = ''")
1260 .Attr("shared_name: string = ''")
1261 .SetShapeFn(shape_inference::UnknownShape)
1262 .SetIsStateful();
1263
1264REGISTER_OP("StagePeek")
1265 .Input("index: int32")
1266 .Output("values: dtypes")
1267 .Attr("capacity: int >= 0 = 0")
1268 .Attr("memory_limit: int >= 0 = 0")
1269 .Attr("dtypes: list(type)")
1270 .Attr("container: string = ''")
1271 .Attr("shared_name: string = ''")
1272 .SetShapeFn(shape_inference::UnknownShape)
1273 .SetIsStateful();
1274
1275REGISTER_OP("StageSize")
1276 .Output("size: int32")
1277 .Attr("capacity: int >= 0 = 0")
1278 .Attr("memory_limit: int >= 0 = 0")
1279 .Attr("dtypes: list(type)")
1280 .Attr("container: string = ''")
1281 .Attr("shared_name: string = ''")
1282 .SetShapeFn(shape_inference::ScalarShape)
1283 .SetIsStateful();
1284
1285REGISTER_OP("StageClear")
1286 .Attr("capacity: int >= 0 = 0")
1287 .Attr("memory_limit: int >= 0 = 0")
1288 .Attr("dtypes: list(type)")
1289 .Attr("container: string = ''")
1290 .Attr("shared_name: string = ''")
1291 .SetShapeFn(shape_inference::UnknownShape)
1292 .SetIsStateful();
1293
1294// UnorderedMap
1295REGISTER_OP("MapStage")
1296 .Input("key: int64")
1297 .Input("indices: int32")
1298 .Input("values: fake_dtypes")
1299 .Attr("capacity: int >= 0 = 0")
1300 .Attr("memory_limit: int >= 0 = 0")
1301 .Attr("dtypes: list(type)")
1302 .Attr("fake_dtypes: list(type)")
1303 .Attr("container: string = ''")
1304 .Attr("shared_name: string = ''")
1305 .SetShapeFn(tensorflow::shape_inference::NoOutputs)
1306 .SetIsStateful();
1307
1308REGISTER_OP("MapPeek")
1309 .Input("key: int64")
1310 .Input("indices: int32")
1311 .Output("values: dtypes")
1312 .Attr("capacity: int >= 0 = 0")
1313 .Attr("memory_limit: int >= 0 = 0")
1314 .Attr("dtypes: list(type)")
1315 .Attr("container: string = ''")
1316 .Attr("shared_name: string = ''")
1317 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1318 .SetIsStateful();
1319
1320REGISTER_OP("MapUnstage")
1321 .Input("key: int64")
1322 .Input("indices: int32")
1323 .Output("values: dtypes")
1324 .Attr("capacity: int >= 0 = 0")
1325 .Attr("memory_limit: int >= 0 = 0")
1326 .Attr("dtypes: list(type)")
1327 .Attr("container: string = ''")
1328 .Attr("shared_name: string = ''")
1329 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1330 .SetIsStateful();
1331
1332REGISTER_OP("MapUnstageNoKey")
1333 .Input("indices: int32")
1334 .Output("key: int64")
1335 .Output("values: dtypes")
1336 .Attr("capacity: int >= 0 = 0")
1337 .Attr("memory_limit: int >= 0 = 0")
1338 .Attr("dtypes: list(type)")
1339 .Attr("container: string = ''")
1340 .Attr("shared_name: string = ''")
1341 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1342 .SetIsStateful();
1343
1344REGISTER_OP("MapSize")
1345 .Output("size: int32")
1346 .Attr("capacity: int >= 0 = 0")
1347 .Attr("memory_limit: int >= 0 = 0")
1348 .Attr("dtypes: list(type)")
1349 .Attr("container: string = ''")
1350 .Attr("shared_name: string = ''")
1351 .SetShapeFn(tensorflow::shape_inference::ScalarShape)
1352 .SetIsStateful();
1353
1354REGISTER_OP("MapIncompleteSize")
1355 .Output("size: int32")
1356 .Attr("capacity: int >= 0 = 0")
1357 .Attr("memory_limit: int >= 0 = 0")
1358 .Attr("dtypes: list(type)")
1359 .Attr("container: string = ''")
1360 .Attr("shared_name: string = ''")
1361 .SetShapeFn(tensorflow::shape_inference::ScalarShape)
1362 .SetIsStateful();
1363
1364REGISTER_OP("MapClear")
1365 .Attr("capacity: int >= 0 = 0")
1366 .Attr("memory_limit: int >= 0 = 0")
1367 .Attr("dtypes: list(type)")
1368 .Attr("container: string = ''")
1369 .Attr("shared_name: string = ''")
1370 .SetShapeFn(tensorflow::shape_inference::NoOutputs)
1371 .SetIsStateful();
1372
1373// OrderedMap
1374REGISTER_OP("OrderedMapStage")
1375 .Input("key: int64")
1376 .Input("indices: int32")
1377 .Input("values: fake_dtypes")
1378 .Attr("capacity: int >= 0 = 0")
1379 .Attr("memory_limit: int >= 0 = 0")
1380 .Attr("dtypes: list(type)")
1381 .Attr("fake_dtypes: list(type)")
1382 .Attr("container: string = ''")
1383 .Attr("shared_name: string = ''")
1384 .SetShapeFn(tensorflow::shape_inference::NoOutputs)
1385 .SetIsStateful();
1386
1387REGISTER_OP("OrderedMapPeek")
1388 .Input("key: int64")
1389 .Input("indices: int32")
1390 .Output("values: dtypes")
1391 .Attr("capacity: int >= 0 = 0")
1392 .Attr("memory_limit: int >= 0 = 0")
1393 .Attr("dtypes: list(type)")
1394 .Attr("container: string = ''")
1395 .Attr("shared_name: string = ''")
1396 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1397 .SetIsStateful();
1398
1399REGISTER_OP("OrderedMapUnstage")
1400 .Input("key: int64")
1401 .Input("indices: int32")
1402 .Output("values: dtypes")
1403 .Attr("capacity: int >= 0 = 0")
1404 .Attr("memory_limit: int >= 0 = 0")
1405 .Attr("dtypes: list(type)")
1406 .Attr("container: string = ''")
1407 .Attr("shared_name: string = ''")
1408 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1409 .SetIsStateful();
1410
1411REGISTER_OP("OrderedMapUnstageNoKey")
1412 .Input("indices: int32")
1413 .Output("key: int64")
1414 .Output("values: dtypes")
1415 .Attr("capacity: int >= 0 = 0")
1416 .Attr("memory_limit: int >= 0 = 0")
1417 .Attr("dtypes: list(type)")
1418 .Attr("container: string = ''")
1419 .Attr("shared_name: string = ''")
1420 .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1421 .SetIsStateful();
1422
1423REGISTER_OP("OrderedMapSize")
1424 .Output("size: int32")
1425 .Attr("capacity: int >= 0 = 0")
1426 .Attr("memory_limit: int >= 0 = 0")
1427 .Attr("dtypes: list(type)")
1428 .Attr("container: string = ''")
1429 .Attr("shared_name: string = ''")
1430 .SetShapeFn(tensorflow::shape_inference::ScalarShape)
1431 .SetIsStateful();
1432
1433REGISTER_OP("OrderedMapIncompleteSize")
1434 .Output("size: int32")
1435 .Attr("capacity: int >= 0 = 0")
1436 .Attr("memory_limit: int >= 0 = 0")
1437 .Attr("dtypes: list(type)")
1438 .Attr("container: string = ''")
1439 .Attr("shared_name: string = ''")
1440 .SetShapeFn(tensorflow::shape_inference::ScalarShape)
1441 .SetIsStateful();
1442
1443REGISTER_OP("OrderedMapClear")
1444 .Attr("capacity: int >= 0 = 0")
1445 .Attr("memory_limit: int >= 0 = 0")
1446 .Attr("dtypes: list(type)")
1447 .Attr("container: string = ''")
1448 .Attr("shared_name: string = ''")
1449 .SetShapeFn(tensorflow::shape_inference::NoOutputs)
1450 .SetIsStateful();
1451
1452REGISTER_OP("RecordInput")
1453 .Output("records: string")
1454 .Attr("file_pattern: string")
1455 .Attr("file_random_seed: int = 301")
1456 .Attr("file_shuffle_shift_ratio: float = 0")
1457 .Attr("file_buffer_size: int = 10000")
1458 .Attr("file_parallelism: int = 16")
1459 .Attr("batch_size: int = 32")
1460 .Attr("compression_type: string = ''")
1461 .SetIsStateful()
1462 .SetShapeFn(shape_inference::UnknownShape);
1463
1464} // namespace tensorflow
1465