1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Common utilities for random shuffling.
17
18#ifndef TENSORFLOW_CORE_KERNELS_SHUFFLE_COMMON_H_
19#define TENSORFLOW_CORE_KERNELS_SHUFFLE_COMMON_H_
20
21#include <algorithm>
22#include <functional>
23
24#include "tensorflow/core/framework/op_kernel.h"
25#include "tensorflow/core/framework/tensor_util.h"
26#include "tensorflow/core/lib/random/philox_random.h"
27#include "tensorflow/core/lib/random/random_distributions.h"
28
29namespace tensorflow {
30
31// TODO(irving): If performance is critical, generate output directly instead
32// of an in-place shuffle using a pseudorandom permutation like
33//
34// https://github.com/otherlab/geode/blob/master/geode/random/permute.cpp
35//
36// This is probably also the right thing if we want a GPU version of shuffling.
37
// We use our own version of std::random_shuffle (a forward Fisher-Yates
// shuffle) to guarantee that, for a non-empty range of `size` elements,
// `uniform` is called exactly size - 1 times.
// In-place Fisher-Yates shuffle of the random-access range [first, last).
//
// At every position except the last one, `uniform(k)` is invoked with the
// number k of elements from that position to the end of the range. It must
// return an offset in [0, k). The element at the current position is then
// swapped with the element at that offset. A range of n >= 1 elements
// therefore makes exactly n - 1 calls to `uniform`, and an empty range makes
// none.
template <class Iter, class Random>
static inline void ShuffleRange(Iter first, Iter last, Random& uniform) {
  using std::iter_swap;
  if (first != last) {
    // `last - 1` is only formed for a non-empty range.
    for (Iter pos = first, final_pos = last - 1; pos != final_pos; ++pos) {
      const auto remaining = last - pos;
      iter_swap(pos, pos + uniform(remaining));
    }
  }
}
49
// Writes a random permutation of the dimension-0 slices of `input_mat` into
// `output_mat`.
//
// Builds the identity permutation over `size` indices of type IntT and
// shuffles it with ShuffleRange. When size > 0, this makes size - 1 calls to
// `uniform`. Output slice r is then assigned from input slice order[r].
// IntT must be able to represent every value in [0, size].
template <class IntT, class InT, class OutT, class Random>
static void IndexedShuffle(const int64_t size, const InT& input_mat,
                           OutT output_mat, Random& uniform) {
  std::vector<IntT> order(size);
  IntT next_index = 0;
  for (IntT& slot : order) {
    slot = next_index++;
  }
  ShuffleRange(order.begin(), order.end(), uniform);
  for (IntT row = 0; row < size; ++row) {
    output_mat.template chip<0>(row) = input_mat.template chip<0>(order[row]);
  }
}
62
63template <typename T>
64Status RandomShuffle(OpKernelContext* context, const Tensor& input,
65 int output_idx,
66 std::function<random::PhiloxRandom(int64_t)> get_rng) {
67 if (input.NumElements() <= 1 || input.dim_size(0) <= 1) {
68 // No shuffling is required, so copy input directly to output
69 context->set_output(output_idx, input);
70 } else {
71 // Reserve enough random samples for shuffling
72 const int64_t size = input.dim_size(0);
73 const int64_t samples = size - 1;
74 auto rng = get_rng(samples);
75 random::SingleSampleAdapter<random::PhiloxRandom> single(&rng);
76 const auto uniform = [&single](uint32 n) { return single() % n; };
77
78 if (input.dims() == 1) {
79 // For 1D data, copy and then shuffle in place
80 context->set_output(output_idx, tensor::DeepCopy(input));
81 auto vec = context->mutable_output(output_idx)->vec<T>();
82 ShuffleRange(vec.data(), vec.data() + size, uniform);
83 } else {
84 // For >= 2D, shuffle indices and then copy across
85 Tensor* output = nullptr;
86 TF_RETURN_IF_ERROR(
87 context->allocate_output(output_idx, input.shape(), &output));
88 const auto input_mat = input.flat_outer_dims<T>();
89 auto output_mat = output->flat_outer_dims<T>();
90 if (size < kint32max) {
91 IndexedShuffle<int32>(size, input_mat, output_mat, uniform);
92 } else {
93 IndexedShuffle<int64_t>(size, input_mat, output_mat, uniform);
94 }
95 }
96 }
97 return OkStatus();
98}
99
100} // namespace tensorflow
101
102#endif // TENSORFLOW_CORE_KERNELS_SHUFFLE_COMMON_H_
103