1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <algorithm>
17#include <array>
18#include <memory>
19
20#include "absl/strings/str_format.h"
21#include "absl/strings/string_view.h"
22#include "tensorflow/core/framework/op.h"
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/framework/op_requires.h"
25#include "tensorflow/core/framework/register_types.h"
26#include "tensorflow/core/framework/shape_inference.h"
27#include "tensorflow/core/framework/tensor_shape.h"
28#include "tensorflow/core/framework/tensor_util.h"
29#include "tensorflow/core/framework/types.pb.h"
30#include "tensorflow/core/kernels/random_index_shuffle.h"
31#include "tensorflow/core/platform/errors.h"
32#include "tensorflow/core/platform/types.h"
33#include "tensorflow/core/profiler/lib/traceme.h"
34#include "tensorflow/core/protobuf/error_codes.pb.h"
35
36namespace tensorflow {
37namespace {
38
39constexpr absl::string_view kRounds = "rounds";
40
41template <typename DType>
42std::array<uint32_t, 3> CastSeedFrom(const Tensor& seed_t, const int row) {
43 const auto seed_vals = seed_t.flat<DType>();
44 return {static_cast<uint32_t>(seed_vals(3 * row)),
45 static_cast<uint32_t>(seed_vals(3 * row + 1)),
46 static_cast<uint32_t>(seed_vals(3 * row + 2))};
47}
48
49Status GetSeed(const Tensor& seed_t, const int row,
50 std::array<uint32_t, 3>* seed) {
51 if (seed_t.dtype() == DT_INT32) {
52 *seed = CastSeedFrom<int32_t>(seed_t, row);
53 } else if (seed_t.dtype() == DT_UINT32) {
54 *seed = CastSeedFrom<uint32_t>(seed_t, row);
55 } else if (seed_t.dtype() == DT_INT64) {
56 *seed = CastSeedFrom<int64_t>(seed_t, row);
57 } else if (seed_t.dtype() == DT_UINT64) {
58 *seed = CastSeedFrom<uint64_t>(seed_t, row);
59 } else {
60 return errors::InvalidArgument("Invalid seed type: ",
61 DataTypeString(seed_t.dtype()));
62 }
63 return OkStatus();
64}
65
66template <typename IntType>
67class RandomIndexShuffleOp : public OpKernel {
68 public:
69 explicit RandomIndexShuffleOp(OpKernelConstruction* context)
70 : OpKernel(context) {
71 OP_REQUIRES_OK(context, context->GetAttr(kRounds, &rounds_));
72 }
73
74 void Compute(OpKernelContext* context) override {
75 const Tensor& index_t = context->input(0);
76 const Tensor& seed_t = context->input(1);
77 const Tensor& max_index_t = context->input(2);
78
79 const bool all_scalar =
80 index_t.dims() == 0 && seed_t.dims() == 1 && max_index_t.dims() == 0;
81 const int64_t num_outputs =
82 std::max(std::max(index_t.NumElements(), max_index_t.NumElements()),
83 seed_t.NumElements() / 3);
84
85 // Check shapes.
86 OP_REQUIRES(context,
87 index_t.dims() == 0 ||
88 (index_t.dims() == 1 && index_t.dim_size(0) == num_outputs),
89 errors::InvalidArgument("Index bust be a scalar or vector."));
90 OP_REQUIRES(context,
91 (seed_t.dims() == 1 && seed_t.dim_size(0) == 3) ||
92 (seed_t.dims() == 2 && seed_t.dim_size(0) == num_outputs &&
93 seed_t.dim_size(1) == 3),
94 errors::InvalidArgument(absl::StrFormat(
95 "Seed must be a vector of size [3] "
96 "or a matrix of size [%d, 3] but got %s.",
97 num_outputs, seed_t.shape().DebugString())));
98 OP_REQUIRES(
99 context,
100 max_index_t.dims() == 0 ||
101 (max_index_t.dims() == 1 && max_index_t.dim_size(0) == num_outputs),
102 errors::InvalidArgument(
103 absl::StrFormat("Maxval must be a scalar or a vector of "
104 "the same size as index but got %s",
105 max_index_t.shape().DebugString())));
106
107 // Create output tensor.
108 Tensor* new_index_t;
109 if (all_scalar) {
110 OP_REQUIRES_OK(
111 context, context->allocate_output(0, index_t.shape(), &new_index_t));
112 } else {
113 TensorShape new_index_shape({num_outputs});
114 OP_REQUIRES_OK(
115 context, context->allocate_output(0, new_index_shape, &new_index_t));
116 }
117
118 for (int64_t i = 0; i < num_outputs; ++i) {
119 const auto index =
120 static_cast<uint64_t>(index_t.dims() ? index_t.vec<IntType>()(i)
121 : index_t.scalar<IntType>()());
122 const auto max_index = static_cast<uint64_t>(
123 max_index_t.dims() ? max_index_t.vec<IntType>()(i)
124 : max_index_t.scalar<IntType>()());
125 std::array<uint32_t, 3> seed;
126 OP_REQUIRES_OK(context,
127 GetSeed(seed_t, seed_t.dims() == 1 ? 0 : i, &seed));
128 const auto new_index =
129 tensorflow::random::index_shuffle(index, seed, max_index, rounds_);
130 new_index_t->flat<IntType>()(i) = static_cast<IntType>(new_index);
131 }
132 }
133
134 private:
135 int32_t rounds_; // Number of rounds for the block cipher.
136
137 TF_DISALLOW_COPY_AND_ASSIGN(RandomIndexShuffleOp);
138};
139
140#define REGISTER(TYPE) \
141 REGISTER_KERNEL_BUILDER(Name("RandomIndexShuffle") \
142 .Device(DEVICE_CPU) \
143 .TypeConstraint<TYPE>("dtype"), \
144 RandomIndexShuffleOp<TYPE>);
145
146TF_CALL_int32(REGISTER);
147TF_CALL_int64(REGISTER);
148TF_CALL_uint32(REGISTER);
149TF_CALL_uint64(REGISTER);
150
151} // namespace
152} // namespace tensorflow
153