1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <algorithm>
17
18#include "tensorflow/core/framework/common_shape_fns.h"
19#include "tensorflow/core/framework/op.h"
20#include "tensorflow/core/framework/shape_inference.h"
21#include "tensorflow/core/platform/errors.h"
22
23namespace tensorflow {
24
25using shape_inference::InferenceContext;
26using shape_inference::ShapeHandle;
27
28namespace {
29
30static Status StatelessRandomPermuteShape(InferenceContext* c) {
31 ShapeHandle index_shape, seed_shape, max_index_shape, rounds_shape;
32
33 // Basic constraints but unknown ranks will not raise errors here.
34 // index, seed and max_index can be scalars or vectors (when batching).
35 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &index_shape));
36 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &seed_shape));
37 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 2, &seed_shape));
38 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &max_index_shape));
39
40 // Figure out if the output is a scalar or tensor.
41 const int32 index_rank = c->Rank(index_shape);
42 const int32 seed_rank = c->Rank(seed_shape);
43 const int32 max_index_rank = c->Rank(max_index_shape);
44
45 // Check that last dimension of seed is 3.
46 if (seed_rank == 1 && c->Value(c->Dim(seed_shape, 0)) != 3) {
47 return errors::InvalidArgument("Seed must have shape [3] but got [",
48 c->Value(c->Dim(seed_shape, 0)), "].");
49 }
50 if (seed_rank == 2 && c->Value(c->Dim(seed_shape, 1)) != 3) {
51 return errors::InvalidArgument("Seed must have shape [n, 3] but got [",
52 c->Value(c->Dim(seed_shape, 0)), ", ",
53 c->Value(c->Dim(seed_shape, 1)), "].");
54 }
55
56 // If all inputs are scalars the output is a scalar.
57 const bool output_is_scalar =
58 (index_rank == 0 && seed_rank == 1 && max_index_rank == 0);
59 if (output_is_scalar) {
60 c->set_output(0, c->Scalar());
61 return OkStatus();
62 }
63
64 if (!c->FullyDefined(index_shape) || !c->FullyDefined(seed_shape) ||
65 !c->FullyDefined(max_index_shape)) {
66 const bool output_is_vector =
67 (index_rank == 1 || seed_rank == 2 || max_index_rank == 1);
68 if (output_is_vector) {
69 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
70 }
71 return OkStatus();
72 }
73
74 // Shape is fully defined and the output is a vector.
75 const int64_t num_indices = index_rank ? c->Value(c->Dim(index_shape, 0)) : 1;
76 const int64_t num_seeds =
77 seed_rank == 2 ? c->Value(c->Dim(seed_shape, 0)) : 1;
78 const int64_t num_max_indices =
79 max_index_rank ? c->Value(c->Dim(max_index_shape, 0)) : 1;
80 const int64_t num_outputs =
81 std::max(std::max(num_indices, num_seeds), num_max_indices);
82 if (num_indices != 1 && num_indices != num_outputs) {
83 return errors::InvalidArgument("Index has shape [", num_indices,
84 "] but must have shape [", num_outputs,
85 "].");
86 }
87 if (num_seeds != 1 && num_seeds != num_outputs) {
88 return errors::InvalidArgument("Seed has shape [", num_seeds,
89 "3, ] but must have shape [", num_outputs,
90 ", 3].");
91 }
92 if (num_max_indices != 1 && num_max_indices != num_outputs) {
93 return errors::InvalidArgument("Max index has shape [", num_max_indices,
94 "] but must have shape [", num_outputs,
95 "].");
96 }
97 c->set_output(0, c->Vector(num_outputs));
98 return OkStatus();
99}
100
// Registers the RandomIndexShuffle op. Its output shape is computed by
// StatelessRandomPermuteShape: index and max_index are scalars or vectors,
// seed has shape [3] or [n, 3], and the output is a scalar or a vector whose
// length is the broadcast batch size.
// NOTE(review): judging by the name, this presumably maps `index` to its
// position under a seeded pseudo-random permutation bounded by `max_index`;
// confirm against the kernel implementation.
// The order of the builder calls below defines the OpDef (input order and
// attr order), so do not reorder them.
REGISTER_OP("RandomIndexShuffle")
    .Input("index: dtype")
    .Input("seed: Tseed")
    .Input("max_index: dtype")
    // index, max_index and output all share the type attr `dtype`.
    .Output("output: dtype")
    // NOTE(review): presumably the number of rounds of the underlying
    // permutation algorithm (default 4); confirm in the kernel.
    .Attr("rounds: int = 4")
    .Attr("dtype: {int32, uint32, int64, uint64}")
    // The seed type is independent of the index type.
    .Attr("Tseed: {int32, uint32, int64, uint64}")
    .SetShapeFn(StatelessRandomPermuteShape);
110
111} // namespace
112} // namespace tensorflow
113