1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/op_def_builder.h" |
19 | #include "tensorflow/core/framework/shape_inference.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | using shape_inference::DimensionHandle; |
24 | using shape_inference::InferenceContext; |
25 | using shape_inference::ShapeHandle; |
26 | |
27 | namespace { |
28 | |
29 | Status DequeueManyV2Shape(InferenceContext* c, ShapeHandle n_shape) { |
30 | auto* t = c->input_handle_shapes_and_types(0); |
31 | if (t != nullptr && t->size() == c->num_outputs()) { |
32 | for (int i = 0; i < c->num_outputs(); ++i) { |
33 | ShapeHandle combined_shape; |
34 | TF_RETURN_IF_ERROR( |
35 | c->Concatenate(n_shape, (*t)[i].shape, &combined_shape)); |
36 | c->set_output(i, combined_shape); |
37 | } |
38 | return OkStatus(); |
39 | } else { |
40 | return shape_inference::UnknownShape(c); |
41 | } |
42 | } |
43 | |
44 | } // namespace |
45 | |
46 | // -------------------------------------------------------------------------- |
47 | |
48 | REGISTER_OP("DynamicPartition" ) |
49 | .Input("data: T" ) |
50 | .Input("partitions: int32" ) |
51 | .Output("outputs: num_partitions * T" ) |
52 | .Attr("num_partitions: int" ) |
53 | .Attr("T: type" ) |
54 | .SetShapeFn([](InferenceContext* c) { |
55 | int64_t num_partitions; |
56 | TF_RETURN_IF_ERROR(c->GetAttr("num_partitions" , &num_partitions)); |
57 | |
58 | ShapeHandle data_shape = c->input(0); |
59 | ShapeHandle partitions_shape = c->input(1); |
60 | |
61 | if (!c->RankKnown(partitions_shape)) { |
62 | return shape_inference::UnknownShape(c); |
63 | } |
64 | |
65 | const int64_t rank = c->Rank(partitions_shape); |
66 | |
67 | // data shape must start with partitions_shape |
68 | ShapeHandle unused; |
69 | TF_RETURN_IF_ERROR( |
70 | c->MergePrefix(data_shape, partitions_shape, &unused, &unused)); |
71 | |
72 | // The partition shape is dynamic in the 0th dimension, and matches |
73 | // data_shape in the remaining dimensions. |
74 | ShapeHandle unknown_dim0 = c->MakeShape({c->UnknownDim()}); |
75 | |
76 | ShapeHandle data_suffix_shape; |
77 | TF_RETURN_IF_ERROR(c->Subshape(data_shape, rank, &data_suffix_shape)); |
78 | ShapeHandle result_shape; |
79 | TF_RETURN_IF_ERROR( |
80 | c->Concatenate(unknown_dim0, data_suffix_shape, &result_shape)); |
81 | |
82 | for (int i = 0; i < c->num_outputs(); ++i) { |
83 | c->set_output(i, result_shape); |
84 | } |
85 | |
86 | return OkStatus(); |
87 | }); |
88 | |
89 | namespace { |
90 | |
91 | Status DynamicStitchShapeFunction(InferenceContext* c) { |
92 | int32_t num_partitions; |
93 | TF_RETURN_IF_ERROR(c->GetAttr("N" , &num_partitions)); |
94 | |
95 | bool all_indices_constant = true; |
96 | int32_t max_index = -1; |
97 | ShapeHandle = c->UnknownShape(); |
98 | for (int i = 0; i < num_partitions; ++i) { |
99 | const Tensor* indices_t = c->input_tensor(i); |
100 | if (indices_t == nullptr) { |
101 | all_indices_constant = false; |
102 | } |
103 | |
104 | ShapeHandle indices_shape = c->input(i); |
105 | ShapeHandle data_shape = c->input(i + num_partitions); |
106 | if (!c->RankKnown(indices_shape)) { |
107 | continue; |
108 | } |
109 | const int64_t indices_rank = c->Rank(indices_shape); |
110 | |
111 | // Assert that data_shape starts with indices_shape. |
112 | ShapeHandle unused; |
113 | TF_RETURN_IF_ERROR( |
114 | c->MergePrefix(data_shape, indices_shape, &unused, &unused)); |
115 | |
116 | // The rest belongs to output. |
117 | ShapeHandle rest; |
118 | TF_RETURN_IF_ERROR(c->Subshape(data_shape, indices_rank, &rest)); |
119 | TF_RETURN_IF_ERROR(c->Merge(extra_shape, rest, &extra_shape)); |
120 | |
121 | if (indices_t != nullptr) { |
122 | // The length is based on the highest index from flattened indices. |
123 | const int32* indices = indices_t->flat<int32>().data(); |
124 | int64_t count = indices_t->NumElements(); |
125 | for (int64_t i = 0; i < count; ++i) { |
126 | if (indices[i] > max_index) { |
127 | max_index = indices[i]; |
128 | } |
129 | } |
130 | } |
131 | } |
132 | |
133 | ShapeHandle output_shape = c->Vector( |
134 | all_indices_constant ? c->MakeDim(max_index + 1) : c->UnknownDim()); |
135 | TF_RETURN_IF_ERROR(c->Concatenate(output_shape, extra_shape, &output_shape)); |
136 | c->set_output(0, output_shape); |
137 | return OkStatus(); |
138 | } |
139 | |
140 | } // namespace |
141 | |
// Takes N int32 `indices` inputs and N `data` inputs of type T and produces a
// single `merged` output. The output shape is computed by
// DynamicStitchShapeFunction.
REGISTER_OP("DynamicStitch")
    .Input("indices: N * int32")
    .Input("data: N * T")
    .Output("merged: T")
    .Attr("N : int >= 1")
    .Attr("T : type")
    .SetShapeFn(DynamicStitchShapeFunction);

// Same inputs, outputs, attrs and shape function as DynamicStitch.
REGISTER_OP("ParallelDynamicStitch")
    .Input("indices: N * int32")
    .Input("data: N * T")
    .Output("merged: T")
    .Attr("N : int >= 1")
    .Attr("T : type")
    .SetShapeFn(DynamicStitchShapeFunction);
157 | |
158 | // -------------------------------------------------------------------------- |
159 | |
160 | namespace { |
161 | Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { |
162 | ShapeHandle handle; |
163 | DimensionHandle unused_handle; |
164 | for (int i = 0; i < c->num_inputs(); ++i) { |
165 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); |
166 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); |
167 | } |
168 | for (int i = 0; i < c->num_outputs(); ++i) { |
169 | c->set_output(i, c->Scalar()); |
170 | } |
171 | return OkStatus(); |
172 | } |
173 | |
174 | Status TwoElementOutput(InferenceContext* c) { |
175 | c->set_output(0, c->Vector(2)); |
176 | return OkStatus(); |
177 | } |
178 | } // namespace |
179 | |
// Queue constructors. Each op below without a "V2" suffix outputs a
// `handle: Ref(string)` whose shape is reported as a length-2 vector
// (TwoElementOutput). Each "V2" op outputs a `handle: resource` whose shape is
// reported as a scalar (ScalarShape). All of them are stateful.
//
// Visible attr defaults: `capacity` = -1, `container` = '' and
// `shared_name` = ''. `shapes` defaults to [] except on the priority queues,
// where `shapes` is required and `component_types` defaults to [] instead.
REGISTER_OP("RandomShuffleQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("min_after_dequeue: int = 0")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("RandomShuffleQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("min_after_dequeue: int = 0")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("FIFOQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("FIFOQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("PaddingFIFOQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("PaddingFIFOQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Note the swapped attr requirements relative to the queues above:
// `component_types` may be empty (default []) while `shapes` has no default.
REGISTER_OP("PriorityQueue")
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 0 = []")
    .Attr("shapes: list(shape) >= 0")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("PriorityQueueV2")
    .Output("handle: resource")
    .Attr("component_types: list(type) >= 0 = []")
    .Attr("shapes: list(shape) >= 0")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Takes a `resource` input and outputs a ref-typed queue handle whose shape
// is reported as a length-2 vector.
REGISTER_OP("FakeQueue")
    .Input("resource: resource")
    .Output("handle: Ref(string)")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);
271 | |
// Enqueue ops (ref and resource handles, single and "Many" variants). They
// declare no outputs; `components` is a list whose dtypes come from the
// `Tcomponents` attr. `timeout_ms` defaults to -1.
REGISTER_OP("QueueEnqueue")
    .Input("handle: Ref(string)")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("QueueEnqueueV2")
    .Input("handle: resource")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("QueueEnqueueMany")
    .Input("handle: Ref(string)")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("QueueEnqueueManyV2")
    .Input("handle: resource")
    .Input("components: Tcomponents")
    .Attr("Tcomponents: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);

// Ref-handle dequeue: one output per `component_types` entry. The output
// shapes are not inferred (UnknownShape), unlike QueueDequeueV2, which reads
// them from the resource handle's shape data.
REGISTER_OP("QueueDequeue")
    .Input("handle: Ref(string)")
    .Output("components: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);
306 | |
307 | REGISTER_OP("QueueDequeueV2" ) |
308 | .Input("handle: resource" ) |
309 | .Output("components: component_types" ) |
310 | .Attr("component_types: list(type) >= 1" ) |
311 | .Attr("timeout_ms: int = -1" ) |
312 | .SetShapeFn([](InferenceContext* c) { |
313 | auto* t = c->input_handle_shapes_and_types(0); |
314 | if (t != nullptr && t->size() == c->num_outputs()) { |
315 | for (int i = 0; i < c->num_outputs(); ++i) { |
316 | c->set_output(i, (*t)[i].shape); |
317 | } |
318 | return OkStatus(); |
319 | } else { |
320 | return shape_inference::UnknownShape(c); |
321 | } |
322 | }); |
323 | |
// Ref-handle variant of QueueDequeueManyV2. Takes an int32 `n` input; the
// output shapes are not inferred (UnknownShape).
REGISTER_OP("QueueDequeueMany")
    .Input("handle: Ref(string)")
    .Input("n: int32")
    .Output("components: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);
331 | |
332 | REGISTER_OP("QueueDequeueManyV2" ) |
333 | .Input("handle: resource" ) |
334 | .Input("n: int32" ) |
335 | .Output("components: component_types" ) |
336 | .Attr("component_types: list(type) >= 1" ) |
337 | .Attr("timeout_ms: int = -1" ) |
338 | .SetShapeFn([](InferenceContext* c) { |
339 | ShapeHandle n_shape; |
340 | if (c->input_tensor(1) == nullptr) { |
341 | n_shape = c->Vector(InferenceContext::kUnknownDim); |
342 | } else { |
343 | const int32_t n = c->input_tensor(1)->scalar<int32>()(); |
344 | if (n < 0) { |
345 | return errors::InvalidArgument("Input 'n' must be >= 0, but is " , n); |
346 | } |
347 | n_shape = c->Vector(n); |
348 | } |
349 | return DequeueManyV2Shape(c, n_shape); |
350 | }); |
351 | |
// Ref-handle variant of QueueDequeueUpToV2. Takes an int32 `n` input; the
// output shapes are not inferred (UnknownShape).
REGISTER_OP("QueueDequeueUpTo")
    .Input("handle: Ref(string)")
    .Input("n: int32")
    .Output("components: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);
359 | |
360 | REGISTER_OP("QueueDequeueUpToV2" ) |
361 | .Input("handle: resource" ) |
362 | .Input("n: int32" ) |
363 | .Output("components: component_types" ) |
364 | .Attr("component_types: list(type) >= 1" ) |
365 | .Attr("timeout_ms: int = -1" ) |
366 | .SetShapeFn([](InferenceContext* c) { |
367 | return DequeueManyV2Shape(c, c->Vector(InferenceContext::kUnknownDim)); |
368 | }); |
369 | |
// Ref-handle close: validates that the handle is a length-2 vector. The op
// declares no outputs.
REGISTER_OP("QueueClose")
    .Input("handle: Ref(string)")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
    .Attr("cancel_pending_enqueues: bool = false");

// Resource-handle close: no outputs and no shape checks.
REGISTER_OP("QueueCloseV2")
    .Input("handle: resource")
    .SetShapeFn(shape_inference::NoOutputs)
    .Attr("cancel_pending_enqueues: bool = false");

// Both is-closed variants produce a scalar bool.
REGISTER_OP("QueueIsClosed")
    .Input("handle: Ref(string)")
    .Output("is_closed: bool")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("QueueIsClosedV2")
    .Input("handle: resource")
    .Output("is_closed: bool")
    .SetShapeFn(shape_inference::ScalarShape);

// Ref-handle size: validates the length-2 handle and reports a scalar `size`.
REGISTER_OP("QueueSize")
    .Input("handle: Ref(string)")
    .Output("size: int32")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);

// Resource-handle size. UnchangedShape presumably forwards the shape of
// `handle` (input 0) to `size`, i.e. a scalar for a scalar resource handle.
// TODO(review): confirm against common_shape_fns.
REGISTER_OP("QueueSizeV2")
    .Input("handle: resource")
    .Output("size: int32")
    .SetShapeFn(shape_inference::UnchangedShape);
399 | |
400 | // -------------------------------------------------------------------------- |
401 | |
402 | REGISTER_OP("AccumulatorNumAccumulated" ) |
403 | .Input("handle: Ref(string)" ) |
404 | .Output("num_accumulated: int32" ) |
405 | .SetShapeFn(shape_inference::ScalarShape); |
406 | |
407 | REGISTER_OP("AccumulatorSetGlobalStep" ) |
408 | .Input("handle: Ref(string)" ) |
409 | .Input("new_global_step: int64" ) |
410 | .SetShapeFn([](InferenceContext* c) { |
411 | ShapeHandle unused; |
412 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
413 | return OkStatus(); |
414 | }); |
415 | |
416 | REGISTER_OP("ConditionalAccumulator" ) |
417 | .Output("handle: Ref(string)" ) |
418 | .Attr("dtype: numbertype" ) |
419 | .Attr("shape: shape" ) |
420 | .Attr("container: string = ''" ) |
421 | .Attr("shared_name: string = ''" ) |
422 | .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' " ) |
423 | .SetIsStateful() |
424 | .SetShapeFn([](InferenceContext* c) { |
425 | c->set_output(0, c->Vector(2)); |
426 | return OkStatus(); |
427 | }); |
428 | |
429 | REGISTER_OP("AccumulatorApplyGradient" ) |
430 | .Input("handle: Ref(string)" ) |
431 | .Input("local_step: int64" ) |
432 | .Input("gradient: dtype" ) |
433 | .Attr("dtype: numbertype" ) |
434 | .SetShapeFn([](InferenceContext* c) { |
435 | ShapeHandle unused; |
436 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
437 | return OkStatus(); |
438 | }); |
439 | |
440 | REGISTER_OP("AccumulatorTakeGradient" ) |
441 | .Input("handle: Ref(string)" ) |
442 | .Input("num_required: int32" ) |
443 | .Output("average: dtype" ) |
444 | .SetShapeFn([](InferenceContext* c) { |
445 | ShapeHandle unused; |
446 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
447 | // Shape of output is the shape of the accumulator referenced |
448 | // by 'handle', but which is not available here, so we lose |
449 | // shape information. |
450 | return shape_inference::UnknownShape(c); |
451 | }) |
452 | .Attr("dtype: numbertype" ); |
453 | |
454 | // -----------------V2 accumulators that use resource ------------------------- |
455 | |
456 | REGISTER_OP("ResourceAccumulatorNumAccumulated" ) |
457 | .Input("handle: resource" ) |
458 | .Output("num_accumulated: int32" ) |
459 | .SetShapeFn(shape_inference::ScalarShape); |
460 | |
461 | REGISTER_OP("ResourceAccumulatorSetGlobalStep" ) |
462 | .Input("handle: resource" ) |
463 | .Input("new_global_step: int64" ) |
464 | .SetShapeFn([](InferenceContext* c) { |
465 | ShapeHandle unused; |
466 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
467 | return OkStatus(); |
468 | }); |
469 | |
470 | REGISTER_OP("ResourceConditionalAccumulator" ) |
471 | .Output("handle: resource" ) |
472 | .Attr("dtype: numbertype" ) |
473 | .Attr("shape: shape" ) |
474 | .Attr("container: string = ''" ) |
475 | .Attr("shared_name: string = ''" ) |
476 | .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' " ) |
477 | .SetIsStateful() |
478 | .SetShapeFn([](InferenceContext* c) { |
479 | c->set_output(0, c->Vector(2)); |
480 | return OkStatus(); |
481 | }); |
482 | |
483 | REGISTER_OP("ResourceAccumulatorApplyGradient" ) |
484 | .Input("handle: resource" ) |
485 | .Input("local_step: int64" ) |
486 | .Input("gradient: dtype" ) |
487 | .Attr("dtype: numbertype" ) |
488 | .SetShapeFn([](InferenceContext* c) { |
489 | ShapeHandle unused; |
490 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
491 | return OkStatus(); |
492 | }); |
493 | |
494 | REGISTER_OP("ResourceAccumulatorTakeGradient" ) |
495 | .Input("handle: resource" ) |
496 | .Input("num_required: int32" ) |
497 | .Output("average: dtype" ) |
498 | .SetShapeFn([](InferenceContext* c) { |
499 | ShapeHandle unused; |
500 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
501 | // Shape of output is the shape of the accumulator referenced |
502 | // by 'handle', but which is not available here, so we lose |
503 | // shape information. |
504 | return shape_inference::UnknownShape(c); |
505 | }) |
506 | .Attr("dtype: numbertype" ); |
507 | |
508 | // TODO(nponomareva): change these all to use resources. |
509 | REGISTER_OP("SparseConditionalAccumulator" ) |
510 | .Output("handle: Ref(string)" ) |
511 | .Attr("dtype: numbertype" ) |
512 | .Attr("shape: shape" ) |
513 | .Attr("container: string = ''" ) |
514 | .Attr("shared_name: string = ''" ) |
515 | .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' " ) |
516 | .SetIsStateful() |
517 | .SetShapeFn([](InferenceContext* c) { |
518 | c->set_output(0, c->Vector(2)); |
519 | return OkStatus(); |
520 | }); |
521 | |
522 | REGISTER_OP("SparseAccumulatorApplyGradient" ) |
523 | .Input("handle: Ref(string)" ) |
524 | .Input("local_step: int64" ) |
525 | .Input("gradient_indices: int64" ) |
526 | .Input("gradient_values: dtype" ) |
527 | .Input("gradient_shape: int64" ) |
528 | .Attr("dtype: numbertype" ) |
529 | .Attr("has_known_shape: bool" ) |
530 | .SetShapeFn([](InferenceContext* c) { |
531 | ShapeHandle unused; |
532 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
533 | return OkStatus(); |
534 | }); |
535 | |
536 | REGISTER_OP("SparseAccumulatorTakeGradient" ) |
537 | .Input("handle: Ref(string)" ) |
538 | .Input("num_required: int32" ) |
539 | .Output("indices: int64" ) |
540 | .Output("values: dtype" ) |
541 | .Output("shape: int64" ) |
542 | .Attr("dtype: numbertype" ) |
543 | .SetShapeFn([](InferenceContext* c) { |
544 | ShapeHandle unused; |
545 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
546 | // Shape of output is the shape of the accumulator referenced |
547 | // by 'handle', but which is not available here, so we lose |
548 | // shape information. |
549 | return shape_inference::UnknownShape(c); |
550 | }); |
551 | |
552 | // -------------------------------------------------------------------------- |
553 | |
554 | REGISTER_OP("StackV2" ) |
555 | .Input("max_size: int32" ) |
556 | .Output("handle: resource" ) |
557 | .Attr("elem_type: type" ) |
558 | .Attr("stack_name: string = ''" ) |
559 | .SetIsStateful() |
560 | .SetShapeFn(TwoElementOutput); |
561 | |
562 | REGISTER_OP("StackPushV2" ) |
563 | .Input("handle: resource" ) |
564 | .Input("elem: T" ) |
565 | .Output("output: T" ) |
566 | .Attr("T: type" ) |
567 | .Attr("swap_memory: bool = false" ) |
568 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
569 | c->set_output(0, c->input(1)); |
570 | return OkStatus(); |
571 | }); |
572 | |
573 | REGISTER_OP("StackPopV2" ) |
574 | .Input("handle: resource" ) |
575 | .Output("elem: elem_type" ) |
576 | .Attr("elem_type: type" ) |
577 | .SetShapeFn(shape_inference::UnknownShape); |
578 | |
579 | REGISTER_OP("StackCloseV2" ) |
580 | .Input("handle: resource" ) |
581 | .SetShapeFn(TwoElementVectorInputsAndScalarOutputs); |
582 | |
583 | // Deprecated ref-typed variants of stack. |
584 | |
585 | REGISTER_OP("Stack" ) |
586 | .Output("handle: Ref(string)" ) |
587 | .Attr("elem_type: type" ) |
588 | .Attr("stack_name: string = ''" ) |
589 | .SetIsStateful() |
590 | .SetShapeFn(TwoElementOutput); |
591 | |
592 | REGISTER_OP("StackPush" ) |
593 | .Input("handle: Ref(string)" ) |
594 | .Input("elem: T" ) |
595 | .Output("output: T" ) |
596 | .Attr("T: type" ) |
597 | .Attr("swap_memory: bool = false" ) |
598 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
599 | c->set_output(0, c->input(1)); |
600 | return OkStatus(); |
601 | }); |
602 | |
603 | REGISTER_OP("StackPop" ) |
604 | .Input("handle: Ref(string)" ) |
605 | .Output("elem: elem_type" ) |
606 | .Attr("elem_type: type" ) |
607 | .SetShapeFn(shape_inference::UnknownShape); |
608 | |
609 | REGISTER_OP("StackClose" ) |
610 | .Input("handle: Ref(string)" ) |
611 | .SetShapeFn(TwoElementVectorInputsAndScalarOutputs); |
612 | |
613 | // -------------------------------------------------------------------------- |
614 | |
615 | REGISTER_OP("TensorArrayV3" ) |
616 | .Input("size: int32" ) |
617 | .Attr("dtype: type" ) |
618 | .Attr("element_shape: shape = { unknown_rank: true }" ) |
619 | .Attr("dynamic_size: bool = false" ) |
620 | .Attr("clear_after_read: bool = true" ) |
621 | .Attr("identical_element_shapes: bool = false" ) |
622 | .Attr("tensor_array_name: string = ''" ) |
623 | .Output("handle: resource" ) |
624 | .Output("flow: float" ) |
625 | .SetIsStateful() |
626 | .SetShapeFn([](InferenceContext* c) { |
627 | ShapeHandle unused; |
628 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
629 | c->set_output(0, c->Vector(2)); |
630 | c->set_output(1, c->Scalar()); |
631 | bool identical_shapes; |
632 | TF_RETURN_IF_ERROR( |
633 | c->GetAttr("identical_element_shapes" , &identical_shapes)); |
634 | DataType t; |
635 | TF_RETURN_IF_ERROR(c->GetAttr("dtype" , &t)); |
636 | PartialTensorShape p; |
637 | TF_RETURN_IF_ERROR(c->GetAttr("element_shape" , &p)); |
638 | ShapeHandle s; |
639 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); |
640 | if (c->FullyDefined(s) || identical_shapes) { |
641 | c->set_output_handle_shapes_and_types( |
642 | 0, std::vector<shape_inference::ShapeAndType>{{s, t}}); |
643 | } |
644 | return OkStatus(); |
645 | }); |
646 | |
647 | REGISTER_OP("TensorArrayGradV3" ) |
648 | .Input("handle: resource" ) |
649 | .Input("flow_in: float" ) |
650 | .Output("grad_handle: resource" ) |
651 | .Output("flow_out: float" ) |
652 | .Attr("source: string" ) |
653 | .SetIsStateful() |
654 | .SetShapeFn([](InferenceContext* c) { |
655 | ShapeHandle handle; |
656 | DimensionHandle unused_dim; |
657 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
658 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
659 | c->set_output(0, c->Vector(2)); |
660 | c->set_output(1, c->Scalar()); |
661 | if (c->input_handle_shapes_and_types(0)) { |
662 | c->set_output_handle_shapes_and_types( |
663 | 0, *c->input_handle_shapes_and_types(0)); |
664 | } |
665 | return OkStatus(); |
666 | }); |
667 | |
668 | REGISTER_OP("TensorArrayGradWithShape" ) |
669 | .Input("handle: resource" ) |
670 | .Input("flow_in: float" ) |
671 | .Input("shape_to_prepend: int32" ) |
672 | .Output("grad_handle: resource" ) |
673 | .Output("flow_out: float" ) |
674 | .Attr("source: string" ) |
675 | .SetIsStateful() |
676 | .SetShapeFn([](InferenceContext* c) { |
677 | ShapeHandle handle; |
678 | DimensionHandle unused_dim; |
679 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
680 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
681 | c->set_output(0, c->Vector(2)); |
682 | c->set_output(1, c->Scalar()); |
683 | auto* shape_and_type = c->input_handle_shapes_and_types(0); |
684 | if (shape_and_type) { |
685 | auto input_shape = (*shape_and_type)[0].shape; |
686 | auto dtype = (*shape_and_type)[0].dtype; |
687 | // Note that shape_to_preped is a rank 1 Tensor representing a shape. |
688 | // The size of dimension 0 is the number of dimensions we need to add to |
689 | // output shape. |
690 | int64_t prepend_rank = c->Value(c->Dim(c->input(2), 0)); |
691 | if (c->RankKnown(input_shape) && |
692 | prepend_rank != InferenceContext::kUnknownDim) { |
693 | int32_t input_rank = c->Rank(input_shape); |
694 | std::vector<DimensionHandle> dims; |
695 | dims.reserve(prepend_rank + input_rank); |
696 | for (int i = 0; i < prepend_rank; ++i) { |
697 | dims.push_back(c->UnknownDim()); |
698 | } |
699 | for (int i = 0; i < input_rank; ++i) { |
700 | dims.push_back(c->Dim(input_shape, i)); |
701 | } |
702 | c->set_output_handle_shapes_and_types(0, |
703 | {{c->MakeShape(dims), dtype}}); |
704 | } else { |
705 | c->set_output_handle_shapes_and_types(0, |
706 | {{c->UnknownShape(), dtype}}); |
707 | } |
708 | } |
709 | return OkStatus(); |
710 | }); |
711 | |
712 | REGISTER_OP("TensorArrayWriteV3" ) |
713 | .Input("handle: resource" ) |
714 | .Input("index: int32" ) |
715 | .Input("value: T" ) |
716 | .Input("flow_in: float" ) |
717 | .Output("flow_out: float" ) |
718 | .Attr("T: type" ) |
719 | .SetShapeFn([](InferenceContext* c) { |
720 | ShapeHandle handle; |
721 | DimensionHandle unused_dim; |
722 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
723 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
724 | |
725 | ShapeHandle unused; |
726 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
727 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
728 | |
729 | auto* handle_data = c->input_handle_shapes_and_types(0); |
730 | if (handle_data != nullptr && !handle_data->empty()) { |
731 | shape_inference::ShapeAndType shape_and_type = (*handle_data)[0]; |
732 | ShapeHandle value_shape = c->input(2); |
733 | TF_RETURN_IF_ERROR( |
734 | c->Merge(shape_and_type.shape, value_shape, &unused)); |
735 | } |
736 | |
737 | return shape_inference::ScalarShape(c); |
738 | }); |
739 | |
740 | REGISTER_OP("TensorArrayReadV3" ) |
741 | .Input("handle: resource" ) |
742 | .Input("index: int32" ) |
743 | .Input("flow_in: float" ) |
744 | .Output("value: dtype" ) |
745 | .Attr("dtype: type" ) |
746 | .SetShapeFn([](InferenceContext* c) { |
747 | ShapeHandle handle; |
748 | DimensionHandle unused_dim; |
749 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
750 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
751 | ShapeHandle unused; |
752 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
753 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
754 | auto shapes = c->input_handle_shapes_and_types(0); |
755 | if (shapes != nullptr && !shapes->empty()) { |
756 | ShapeHandle tensor_shape = shapes->at(0).shape; |
757 | c->set_output(0, tensor_shape); |
758 | return OkStatus(); |
759 | } else { |
760 | return shape_inference::UnknownShape(c); |
761 | } |
762 | }); |
763 | |
764 | REGISTER_OP("TensorArrayGatherV3" ) |
765 | .Input("handle: resource" ) |
766 | .Input("indices: int32" ) |
767 | .Input("flow_in: float" ) |
768 | .Output("value: dtype" ) |
769 | .Attr("dtype: type" ) |
770 | .Attr("element_shape: shape = { unknown_rank: true }" ) |
771 | .SetShapeFn([](InferenceContext* c) { |
772 | ShapeHandle indices; |
773 | ShapeHandle unused; |
774 | DimensionHandle unused_dim; |
775 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
776 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices)); |
777 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); |
778 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
779 | auto shapes = c->input_handle_shapes_and_types(0); |
780 | if (shapes != nullptr && !shapes->empty()) { |
781 | ShapeHandle tensor_shape = shapes->at(0).shape; |
782 | ShapeHandle output_shape; |
783 | TF_RETURN_IF_ERROR( |
784 | c->Concatenate(indices, tensor_shape, &output_shape)); |
785 | c->set_output(0, output_shape); |
786 | return OkStatus(); |
787 | } else { |
788 | PartialTensorShape p; |
789 | TF_RETURN_IF_ERROR(c->GetAttr("element_shape" , &p)); |
790 | ShapeHandle s; |
791 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); |
792 | ShapeHandle output_shape; |
793 | TF_RETURN_IF_ERROR(c->Concatenate(indices, s, &output_shape)); |
794 | c->set_output(0, output_shape); |
795 | return OkStatus(); |
796 | } |
797 | }); |
798 | |
799 | REGISTER_OP("TensorArrayScatterV3" ) |
800 | .Input("handle: resource" ) |
801 | .Input("indices: int32" ) |
802 | .Input("value: T" ) |
803 | .Input("flow_in: float" ) |
804 | .Output("flow_out: float" ) |
805 | .Attr("T: type" ) |
806 | .SetShapeFn([](InferenceContext* c) { |
807 | ShapeHandle indices; |
808 | ShapeHandle unused; |
809 | DimensionHandle unused_dim; |
810 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
811 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices)); |
812 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); |
813 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
814 | ShapeHandle value_shape; |
815 | // Assert that the length of the indices tensor is equal to the first |
816 | // dimension of the value tensor. |
817 | TF_RETURN_IF_ERROR( |
818 | c->MergePrefix(c->input(2), indices, &value_shape, &indices)); |
819 | auto shapes = c->input_handle_shapes_and_types(0); |
820 | if (shapes != nullptr && !shapes->empty()) { |
821 | ShapeHandle tensor_shape = shapes->at(0).shape; |
822 | ShapeHandle fed_shape; |
823 | TF_RETURN_IF_ERROR(c->Subshape(value_shape, 1, &fed_shape)); |
824 | TF_RETURN_IF_ERROR(c->Merge(tensor_shape, fed_shape, &fed_shape)); |
825 | } |
826 | return shape_inference::ScalarShape(c); |
827 | }); |
828 | |
829 | REGISTER_OP("TensorArrayConcatV3" ) |
830 | .Input("handle: resource" ) |
831 | .Input("flow_in: float" ) |
832 | .Output("value: dtype" ) |
833 | .Output("lengths: int64" ) |
834 | .Attr("dtype: type" ) |
835 | .Attr("element_shape_except0: shape = { unknown_rank: true }" ) |
836 | .SetShapeFn([](InferenceContext* c) { |
837 | ShapeHandle handle; |
838 | DimensionHandle unused_dim; |
839 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
840 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
841 | ShapeHandle unused; |
842 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
843 | c->set_output(0, c->UnknownShape()); |
844 | c->set_output(1, c->Vector(c->UnknownDim())); |
845 | return OkStatus(); |
846 | }); |
847 | |
848 | REGISTER_OP("TensorArraySplitV3" ) |
849 | .Input("handle: resource" ) |
850 | .Input("value: T" ) |
851 | .Input("lengths: int64" ) |
852 | .Input("flow_in: float" ) |
853 | .Output("flow_out: float" ) |
854 | .Attr("T: type" ) |
855 | .SetShapeFn([](InferenceContext* c) { |
856 | ShapeHandle handle; |
857 | DimensionHandle unused_dim; |
858 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
859 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
860 | ShapeHandle unused; |
861 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
862 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
863 | return shape_inference::ScalarShape(c); |
864 | }); |
865 | |
866 | REGISTER_OP("TensorArraySizeV3" ) |
867 | .Input("handle: resource" ) |
868 | .Input("flow_in: float" ) |
869 | .Output("size: int32" ) |
870 | .SetShapeFn([](InferenceContext* c) { |
871 | ShapeHandle handle; |
872 | DimensionHandle unused_dim; |
873 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
874 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
875 | return shape_inference::ScalarShape(c); |
876 | }); |
877 | |
878 | REGISTER_OP("TensorArrayCloseV3" ) |
879 | .Input("handle: resource" ) |
880 | .SetShapeFn([](InferenceContext* c) { |
881 | ShapeHandle handle; |
882 | DimensionHandle unused_dim; |
883 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
884 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
885 | return OkStatus(); |
886 | }); |
887 | |
888 | // -------------------------------------------------------------------------- |
889 | |
890 | // Deprecated TensorArray methods |
891 | |
// Deprecated at GraphDef version 16 in favor of TensorArrayV3.
// This is the original TensorArray constructor. It returns its handle as a
// string ref. The shape function leaves the handle shape unknown and does
// not validate `size`.
// NOTE(review): the attr order here (dynamic_size, clear_after_read,
// tensor_array_name, element_shape) differs from TensorArrayV2. Declaration
// order is presumably part of the registered OpDef, so do not reorder it.
REGISTER_OP("TensorArray")
    .Input("size: int32")
    .Attr("dtype: type")
    .Attr("dynamic_size: bool = false")
    .Attr("clear_after_read: bool = true")
    .Attr("tensor_array_name: string = ''")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .Output("handle: Ref(string)")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayV3");
903 | REGISTER_OP("TensorArrayV2" ) |
904 | .Input("size: int32" ) |
905 | .Attr("dtype: type" ) |
906 | .Attr("element_shape: shape = { unknown_rank: true }" ) |
907 | .Attr("dynamic_size: bool = false" ) |
908 | .Attr("clear_after_read: bool = true" ) |
909 | .Attr("tensor_array_name: string = ''" ) |
910 | .Output("handle: string" ) |
911 | .SetIsStateful() |
912 | .SetShapeFn([](InferenceContext* c) { |
913 | ShapeHandle unused; |
914 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
915 | c->set_output(0, c->Vector(2)); |
916 | return OkStatus(); |
917 | }) |
918 | .Deprecated(26, "Use TensorArrayV3" ); |
// Deprecated at GraphDef version 16 in favor of TensorArrayGradV3.
// This is the gradient-TensorArray op that returns a string ref handle.
// `source` presumably identifies the gradient source; confirm against the
// kernel. The output shape is left unknown and inputs are not validated.
REGISTER_OP("TensorArrayGrad")
    .Input("handle: string")
    .Input("flow_in: float")
    .Output("grad_handle: Ref(string)")
    .Attr("source: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayGradV3");
927 | REGISTER_OP("TensorArrayGradV2" ) |
928 | .Input("handle: string" ) |
929 | .Input("flow_in: float" ) |
930 | .Output("grad_handle: string" ) |
931 | .Attr("source: string" ) |
932 | .SetIsStateful() |
933 | .SetShapeFn([](InferenceContext* c) { |
934 | ShapeHandle handle; |
935 | DimensionHandle unused_dim; |
936 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
937 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
938 | c->set_output(0, c->Vector(2)); |
939 | return OkStatus(); |
940 | }) |
941 | .Deprecated(26, "Use TensorArrayGradV3" ); |
// Deprecated at GraphDef version 16 in favor of TensorArrayWriteV3.
// This is the string-ref-handle write op. It performs no shape validation,
// and `flow_out` is left with an unknown shape.
REGISTER_OP("TensorArrayWrite")
    .Input("handle: Ref(string)")
    .Input("index: int32")
    .Input("value: T")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayWriteV3");
951 | REGISTER_OP("TensorArrayWriteV2" ) |
952 | .Input("handle: string" ) |
953 | .Input("index: int32" ) |
954 | .Input("value: T" ) |
955 | .Input("flow_in: float" ) |
956 | .Output("flow_out: float" ) |
957 | .Attr("T: type" ) |
958 | .SetShapeFn([](InferenceContext* c) { |
959 | ShapeHandle handle; |
960 | DimensionHandle unused_dim; |
961 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
962 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
963 | |
964 | ShapeHandle unused; |
965 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
966 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
967 | return shape_inference::ScalarShape(c); |
968 | }) |
969 | .Deprecated(26, "Use TensorArrayWriteV3" ); |
// Deprecated at GraphDef version 16 in favor of TensorArrayReadV3.
// This is the string-ref-handle read op. It performs no shape validation,
// and `value` is left with an unknown shape.
REGISTER_OP("TensorArrayRead")
    .Input("handle: Ref(string)")
    .Input("index: int32")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayReadV3");
978 | REGISTER_OP("TensorArrayReadV2" ) |
979 | .Input("handle: string" ) |
980 | .Input("index: int32" ) |
981 | .Input("flow_in: float" ) |
982 | .Output("value: dtype" ) |
983 | .Attr("dtype: type" ) |
984 | .SetShapeFn([](InferenceContext* c) { |
985 | ShapeHandle handle; |
986 | DimensionHandle unused_dim; |
987 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
988 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
989 | ShapeHandle unused; |
990 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
991 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
992 | return shape_inference::UnknownShape(c); |
993 | }) |
994 | .Deprecated(26, "Use TensorArrayReadV3" ); |
// Deprecated at GraphDef version 16. The deprecation message directs users
// to TensorArrayGatherV3 combined with a RangeOp. No shape validation is
// done, and `value` is left with an unknown shape.
REGISTER_OP("TensorArrayPack")
    .Input("handle: Ref(string)")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Attr("dtype: type")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayGatherV3 with RangeOp");
// Deprecated at GraphDef version 20 (later than most v1 TensorArray ops,
// which use 16). The deprecation message directs users to
// TensorArrayScatterV3 combined with a RangeOp. No shape validation is done.
REGISTER_OP("TensorArrayUnpack")
    .Input("handle: Ref(string)")
    .Input("value: T")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(20, "Use TensorArrayScatterV3 with RangeOp");
// Deprecated at GraphDef version 16 in favor of TensorArrayGatherV3.
// This is the string-ref-handle gather op. No shape validation is done, and
// `value` is left with an unknown shape.
REGISTER_OP("TensorArrayGather")
    .Input("handle: Ref(string)")
    .Input("indices: int32")
    .Input("flow_in: float")
    .Output("value: dtype")
    .Attr("dtype: type")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArrayGatherV3");
1020 | REGISTER_OP("TensorArrayGatherV2" ) |
1021 | .Input("handle: string" ) |
1022 | .Input("indices: int32" ) |
1023 | .Input("flow_in: float" ) |
1024 | .Output("value: dtype" ) |
1025 | .Attr("dtype: type" ) |
1026 | .Attr("element_shape: shape = { unknown_rank: true }" ) |
1027 | .SetShapeFn([](InferenceContext* c) { |
1028 | ShapeHandle unused; |
1029 | DimensionHandle unused_dim; |
1030 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
1031 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
1032 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); |
1033 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1034 | return shape_inference::UnknownShape(c); |
1035 | }) |
1036 | .Deprecated(26, "Use TensorArrayGatherV3" ); |
1037 | REGISTER_OP("TensorArrayScatter" ) |
1038 | .Input("handle: Ref(string)" ) |
1039 | .Input("indices: int32" ) |
1040 | .Input("value: T" ) |
1041 | .Input("flow_in: float" ) |
1042 | .Output("flow_out: float" ) |
1043 | .Attr("T: type" ) |
1044 | .SetShapeFn(shape_inference::UnknownShape) |
1045 | .Deprecated(19, "Use TensorArrayGradV3" ); |
1046 | REGISTER_OP("TensorArrayScatterV2" ) |
1047 | .Input("handle: string" ) |
1048 | .Input("indices: int32" ) |
1049 | .Input("value: T" ) |
1050 | .Input("flow_in: float" ) |
1051 | .Output("flow_out: float" ) |
1052 | .Attr("T: type" ) |
1053 | .SetShapeFn([](InferenceContext* c) { |
1054 | ShapeHandle unused; |
1055 | DimensionHandle unused_dim; |
1056 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
1057 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
1058 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); |
1059 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1060 | return shape_inference::ScalarShape(c); |
1061 | }) |
1062 | .Deprecated(26, "Use TensorArrayScatterV3" ); |
1063 | REGISTER_OP("TensorArrayConcat" ) |
1064 | .Input("handle: Ref(string)" ) |
1065 | .Input("flow_in: float" ) |
1066 | .Output("value: dtype" ) |
1067 | .Output("lengths: int64" ) |
1068 | .Attr("dtype: type" ) |
1069 | .Attr("element_shape_except0: shape = { unknown_rank: true }" ) |
1070 | .SetShapeFn(shape_inference::UnknownShape) |
1071 | .Deprecated(16, "Use TensorArrayGradV3" ); |
1072 | REGISTER_OP("TensorArrayConcatV2" ) |
1073 | .Input("handle: string" ) |
1074 | .Input("flow_in: float" ) |
1075 | .Output("value: dtype" ) |
1076 | .Output("lengths: int64" ) |
1077 | .Attr("dtype: type" ) |
1078 | .Attr("element_shape_except0: shape = { unknown_rank: true }" ) |
1079 | .SetShapeFn([](InferenceContext* c) { |
1080 | ShapeHandle handle; |
1081 | DimensionHandle unused_dim; |
1082 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
1083 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
1084 | ShapeHandle unused; |
1085 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1086 | c->set_output(0, c->UnknownShape()); |
1087 | c->set_output(1, c->Vector(c->UnknownDim())); |
1088 | return OkStatus(); |
1089 | }); |
// Deprecated at GraphDef version 16 in favor of TensorArraySplitV3.
// This is the string-ref-handle split op. No shape validation is done, and
// `flow_out` is left with an unknown shape.
REGISTER_OP("TensorArraySplit")
    .Input("handle: Ref(string)")
    .Input("value: T")
    .Input("lengths: int64")
    .Input("flow_in: float")
    .Output("flow_out: float")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArraySplitV3");
1099 | REGISTER_OP("TensorArraySplitV2" ) |
1100 | .Input("handle: string" ) |
1101 | .Input("value: T" ) |
1102 | .Input("lengths: int64" ) |
1103 | .Input("flow_in: float" ) |
1104 | .Output("flow_out: float" ) |
1105 | .Attr("T: type" ) |
1106 | .SetShapeFn([](InferenceContext* c) { |
1107 | ShapeHandle handle; |
1108 | DimensionHandle unused_dim; |
1109 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
1110 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
1111 | ShapeHandle unused; |
1112 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
1113 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1114 | return shape_inference::ScalarShape(c); |
1115 | }) |
1116 | .Deprecated(26, "Use TensorArraySplitV3" ); |
// Deprecated at GraphDef version 16 in favor of TensorArraySizeV3.
// No shape validation is done, and `size` is left with an unknown shape
// rather than a scalar.
REGISTER_OP("TensorArraySize")
    .Input("handle: Ref(string)")
    .Input("flow_in: float")
    .Output("size: int32")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(16, "Use TensorArraySizeV3");
1123 | REGISTER_OP("TensorArraySizeV2" ) |
1124 | .Input("handle: string" ) |
1125 | .Input("flow_in: float" ) |
1126 | .Output("size: int32" ) |
1127 | .SetShapeFn([](InferenceContext* c) { |
1128 | ShapeHandle handle; |
1129 | DimensionHandle unused_dim; |
1130 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
1131 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
1132 | return shape_inference::ScalarShape(c); |
1133 | }) |
1134 | .Deprecated(26, "Use TensorArraySizeV3" ); |
1135 | REGISTER_OP("TensorArrayClose" ) |
1136 | .Input("handle: Ref(string)" ) |
1137 | .SetShapeFn([](InferenceContext* c) { return OkStatus(); }) |
1138 | .Deprecated(16, "Use TensorArrayCloseV3" ); |
1139 | REGISTER_OP("TensorArrayCloseV2" ) |
1140 | .Input("handle: string" ) |
1141 | .SetShapeFn([](InferenceContext* c) { |
1142 | ShapeHandle handle; |
1143 | DimensionHandle unused_dim; |
1144 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
1145 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
1146 | return OkStatus(); |
1147 | }) |
1148 | .Deprecated(26, "Use TensorArrayCloseV3" ); |
1149 | |
1150 | // -------------------------------------------------------------------------- |
1151 | |
// Stateful op that creates a barrier and returns a string ref handle.
// `TwoElementOutput` is a helper defined elsewhere in this file. Judging by
// its name, it presumably sets the handle's shape to a 2-element vector;
// confirm against its definition.
// NOTE(review): the meaning of `capacity = -1` (presumably "unbounded") is
// not established here; check the kernel.
REGISTER_OP("Barrier")
    .SetIsStateful()
    .Output("handle: Ref(string)")
    .Attr("component_types: list(type) >= 1")
    .Attr("shapes: list(shape) >= 0 = []")
    .Attr("capacity: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(TwoElementOutput);
1161 | |
1162 | REGISTER_OP("BarrierInsertMany" ) |
1163 | .Input("handle: Ref(string)" ) |
1164 | .Input("keys: string" ) |
1165 | .Input("values: T" ) |
1166 | .Attr("T: type" ) |
1167 | .Attr("component_index: int" ) |
1168 | .SetShapeFn([](InferenceContext* c) { |
1169 | ShapeHandle keys = c->input(1); |
1170 | ShapeHandle values = c->input(2); |
1171 | ShapeHandle handle; |
1172 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); |
1173 | DimensionHandle unused_dim; |
1174 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); |
1175 | TF_RETURN_IF_ERROR(c->WithRank(keys, 1, &keys)); |
1176 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); |
1177 | TF_RETURN_IF_ERROR(c->Merge(keys, c->Vector(c->Dim(values, 0)), &handle)); |
1178 | return OkStatus(); |
1179 | }); |
1180 | |
// Takes elements from a barrier. Every output (`indices`, `keys`, and each
// `values` component) is given an unknown shape, and inputs are not
// validated by this shape function.
// NOTE(review): `timeout_ms = -1` presumably means "no timeout"; confirm
// against the kernel.
REGISTER_OP("BarrierTakeMany")
    .Input("handle: Ref(string)")
    .Input("num_elements: int32")
    .Output("indices: int64")
    .Output("keys: string")
    .Output("values: component_types")
    .Attr("component_types: list(type) >= 1")
    .Attr("allow_small_batch: bool = false")
    .Attr("wait_for_incomplete: bool = false")
    .Attr("timeout_ms: int = -1")
    .SetShapeFn(shape_inference::UnknownShape);
1192 | |
// Barrier bookkeeping ops. All three use `TwoElementVectorInputsAndScalarOutputs`,
// a helper defined elsewhere in this file. Judging by its name, it
// presumably requires each input to be a 2-element vector and makes each
// output a scalar; confirm against its definition.
REGISTER_OP("BarrierClose")
    .Input("handle: Ref(string)")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
    .Attr("cancel_pending_enqueues: bool = false");

REGISTER_OP("BarrierReadySize")
    .Input("handle: Ref(string)")
    .Output("size: int32")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);

REGISTER_OP("BarrierIncompleteSize")
    .Input("handle: Ref(string)")
    .Output("size: int32")
    .SetShapeFn(TwoElementVectorInputsAndScalarOutputs);
1207 | |
1208 | // -------------------------------------------------------------------------- |
1209 | |
// Stateful ops that take a tensor `value` and return a scalar handle.
// GetSessionHandle returns the handle as a string; GetSessionHandleV2
// returns it as a resource. `value` itself is not shape-checked.
REGISTER_OP("GetSessionHandle")
    .Input("value: T")
    .Output("handle: string")
    .Attr("T: type")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("GetSessionHandleV2")
    .Input("value: T")
    .Output("handle: resource")
    .Attr("T: type")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
1223 | |
1224 | REGISTER_OP("GetSessionTensor" ) |
1225 | .Input("handle: string" ) |
1226 | .Output("value: dtype" ) |
1227 | .Attr("dtype: type" ) |
1228 | .SetIsStateful() |
1229 | .SetShapeFn([](InferenceContext* c) { |
1230 | ShapeHandle unused; |
1231 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
1232 | return shape_inference::UnknownShape(c); |
1233 | }); |
1234 | |
1235 | REGISTER_OP("DeleteSessionTensor" ) |
1236 | .Input("handle: string" ) |
1237 | .SetIsStateful() |
1238 | .SetShapeFn([](InferenceContext* c) { |
1239 | ShapeHandle unused; |
1240 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
1241 | return OkStatus(); |
1242 | }); |
1243 | |
// Staging-area ops. All five declare the same attr set (capacity,
// memory_limit, dtypes, container, shared_name), even the ops that take no
// `dtypes`-typed inputs or outputs. Presumably this lets every op that
// refers to one staging area carry identical attrs; confirm against the
// kernels.
// NOTE(review): a value of 0 for `capacity` / `memory_limit` presumably
// means "no limit"; this is not established here.

// Stage: consumes `values`; no outputs.
REGISTER_OP("Stage")
    .Input("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful();

// Unstage: produces `values`, each with an unknown shape.
REGISTER_OP("Unstage")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful();

// StagePeek: takes an int32 `index` and produces `values`, each with an
// unknown shape. `index` is not shape-checked.
REGISTER_OP("StagePeek")
    .Input("index: int32")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful();

// StageSize: a scalar int32 `size`.
REGISTER_OP("StageSize")
    .Output("size: int32")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::ScalarShape)
    .SetIsStateful();

// StageClear: no inputs or outputs. It uses UnknownShape where the Map*Clear
// ops use NoOutputs; with no outputs, presumably there is nothing to set.
REGISTER_OP("StageClear")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful();
1293 | |
// UnorderedMap
// Keyed staging ops. Each is stateful and declares the same attr set as the
// Stage* ops above. `key` is int64 and `indices` is int32; neither is
// shape-checked by these shape functions.

// MapStage: `values` is typed by `fake_dtypes` rather than `dtypes`.
// Together with `indices`, this presumably allows staging a subset of the
// components; confirm against the kernel. No outputs.
REGISTER_OP("MapStage")
    .Input("key: int64")
    .Input("indices: int32")
    .Input("values: fake_dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("fake_dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
    .SetIsStateful();

// MapPeek: produces `values` for `key`, each with an unknown shape.
REGISTER_OP("MapPeek")
    .Input("key: int64")
    .Input("indices: int32")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// MapUnstage: same signature as MapPeek; outputs have unknown shapes.
REGISTER_OP("MapUnstage")
    .Input("key: int64")
    .Input("indices: int32")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// MapUnstageNoKey: takes no key input and outputs the `key` alongside the
// `values`. All outputs have unknown shapes.
REGISTER_OP("MapUnstageNoKey")
    .Input("indices: int32")
    .Output("key: int64")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// MapSize: a scalar int32 `size`.
REGISTER_OP("MapSize")
    .Output("size: int32")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .SetIsStateful();

// MapIncompleteSize: a scalar int32 `size`.
REGISTER_OP("MapIncompleteSize")
    .Output("size: int32")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .SetIsStateful();

// MapClear: no inputs or outputs.
REGISTER_OP("MapClear")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
    .SetIsStateful();
1372 | |
// OrderedMap
// Ordered counterparts of the Map* ops above. Each op's inputs, outputs,
// attrs, and shape function mirror its Map* sibling exactly. Only the op
// names differ; the ordering semantics live in the kernels (not shown).

// OrderedMapStage: `values` is typed by `fake_dtypes`. No outputs.
REGISTER_OP("OrderedMapStage")
    .Input("key: int64")
    .Input("indices: int32")
    .Input("values: fake_dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("fake_dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
    .SetIsStateful();

// OrderedMapPeek: outputs have unknown shapes.
REGISTER_OP("OrderedMapPeek")
    .Input("key: int64")
    .Input("indices: int32")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// OrderedMapUnstage: outputs have unknown shapes.
REGISTER_OP("OrderedMapUnstage")
    .Input("key: int64")
    .Input("indices: int32")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// OrderedMapUnstageNoKey: outputs `key` and `values`, all with unknown shapes.
REGISTER_OP("OrderedMapUnstageNoKey")
    .Input("indices: int32")
    .Output("key: int64")
    .Output("values: dtypes")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape)
    .SetIsStateful();

// OrderedMapSize: a scalar int32 `size`.
REGISTER_OP("OrderedMapSize")
    .Output("size: int32")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .SetIsStateful();

// OrderedMapIncompleteSize: a scalar int32 `size`.
REGISTER_OP("OrderedMapIncompleteSize")
    .Output("size: int32")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .SetIsStateful();

// OrderedMapClear: no inputs or outputs.
REGISTER_OP("OrderedMapClear")
    .Attr("capacity: int >= 0 = 0")
    .Attr("memory_limit: int >= 0 = 0")
    .Attr("dtypes: list(type)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(tensorflow::shape_inference::NoOutputs)
    .SetIsStateful();
1451 | |
// Stateful op with no inputs that produces string `records`. The shape
// function leaves the output shape unknown.
// NOTE(review): from the attr names, this presumably reads records from
// files matching `file_pattern` and emits `batch_size` records per run. The
// exact batching, shuffling, and parallelism semantics are defined by the
// kernel, not here; confirm before relying on them.
REGISTER_OP("RecordInput")
    .Output("records: string")
    .Attr("file_pattern: string")
    .Attr("file_random_seed: int = 301")
    .Attr("file_shuffle_shift_ratio: float = 0")
    .Attr("file_buffer_size: int = 10000")
    .Attr("file_parallelism: int = 16")
    .Attr("batch_size: int = 32")
    .Attr("compression_type: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
1463 | |
1464 | } // namespace tensorflow |
1465 | |