1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | |
18 | #include "tensorflow/core/framework/common_shape_fns.h" |
19 | #include "tensorflow/core/framework/op.h" |
20 | #include "tensorflow/core/framework/shape_inference.h" |
21 | #include "tensorflow/core/platform/errors.h" |
22 | |
23 | namespace tensorflow { |
24 | |
25 | using shape_inference::InferenceContext; |
26 | using shape_inference::ShapeHandle; |
27 | |
28 | namespace { |
29 | |
30 | static Status StatelessRandomPermuteShape(InferenceContext* c) { |
31 | ShapeHandle index_shape, seed_shape, max_index_shape, rounds_shape; |
32 | |
33 | // Basic constraints but unknown ranks will not raise errors here. |
34 | // index, seed and max_index can be scalars or vectors (when batching). |
35 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &index_shape)); |
36 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &seed_shape)); |
37 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 2, &seed_shape)); |
38 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &max_index_shape)); |
39 | |
40 | // Figure out if the output is a scalar or tensor. |
41 | const int32 index_rank = c->Rank(index_shape); |
42 | const int32 seed_rank = c->Rank(seed_shape); |
43 | const int32 max_index_rank = c->Rank(max_index_shape); |
44 | |
45 | // Check that last dimension of seed is 3. |
46 | if (seed_rank == 1 && c->Value(c->Dim(seed_shape, 0)) != 3) { |
47 | return errors::InvalidArgument("Seed must have shape [3] but got [" , |
48 | c->Value(c->Dim(seed_shape, 0)), "]." ); |
49 | } |
50 | if (seed_rank == 2 && c->Value(c->Dim(seed_shape, 1)) != 3) { |
51 | return errors::InvalidArgument("Seed must have shape [n, 3] but got [" , |
52 | c->Value(c->Dim(seed_shape, 0)), ", " , |
53 | c->Value(c->Dim(seed_shape, 1)), "]." ); |
54 | } |
55 | |
56 | // If all inputs are scalars the output is a scalar. |
57 | const bool output_is_scalar = |
58 | (index_rank == 0 && seed_rank == 1 && max_index_rank == 0); |
59 | if (output_is_scalar) { |
60 | c->set_output(0, c->Scalar()); |
61 | return OkStatus(); |
62 | } |
63 | |
64 | if (!c->FullyDefined(index_shape) || !c->FullyDefined(seed_shape) || |
65 | !c->FullyDefined(max_index_shape)) { |
66 | const bool output_is_vector = |
67 | (index_rank == 1 || seed_rank == 2 || max_index_rank == 1); |
68 | if (output_is_vector) { |
69 | c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
70 | } |
71 | return OkStatus(); |
72 | } |
73 | |
74 | // Shape is fully defined and the output is a vector. |
75 | const int64_t num_indices = index_rank ? c->Value(c->Dim(index_shape, 0)) : 1; |
76 | const int64_t num_seeds = |
77 | seed_rank == 2 ? c->Value(c->Dim(seed_shape, 0)) : 1; |
78 | const int64_t num_max_indices = |
79 | max_index_rank ? c->Value(c->Dim(max_index_shape, 0)) : 1; |
80 | const int64_t num_outputs = |
81 | std::max(std::max(num_indices, num_seeds), num_max_indices); |
82 | if (num_indices != 1 && num_indices != num_outputs) { |
83 | return errors::InvalidArgument("Index has shape [" , num_indices, |
84 | "] but must have shape [" , num_outputs, |
85 | "]." ); |
86 | } |
87 | if (num_seeds != 1 && num_seeds != num_outputs) { |
88 | return errors::InvalidArgument("Seed has shape [" , num_seeds, |
89 | "3, ] but must have shape [" , num_outputs, |
90 | ", 3]." ); |
91 | } |
92 | if (num_max_indices != 1 && num_max_indices != num_outputs) { |
93 | return errors::InvalidArgument("Max index has shape [" , num_max_indices, |
94 | "] but must have shape [" , num_outputs, |
95 | "]." ); |
96 | } |
97 | c->set_output(0, c->Vector(num_outputs)); |
98 | return OkStatus(); |
99 | } |
100 | |
101 | REGISTER_OP("RandomIndexShuffle" ) |
102 | .Input("index: dtype" ) |
103 | .Input("seed: Tseed" ) |
104 | .Input("max_index: dtype" ) |
105 | .Output("output: dtype" ) |
106 | .Attr("rounds: int = 4" ) |
107 | .Attr("dtype: {int32, uint32, int64, uint64}" ) |
108 | .Attr("Tseed: {int32, uint32, int64, uint64}" ) |
109 | .SetShapeFn(StatelessRandomPermuteShape); |
110 | |
111 | } // namespace |
112 | } // namespace tensorflow |
113 | |