1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Common utilities for random shuffling. |
17 | |
18 | #ifndef TENSORFLOW_CORE_KERNELS_SHUFFLE_COMMON_H_ |
19 | #define TENSORFLOW_CORE_KERNELS_SHUFFLE_COMMON_H_ |
20 | |
#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
28 | |
29 | namespace tensorflow { |
30 | |
31 | // TODO(irving): If performance is critical, generate output directly instead |
32 | // of an in-place shuffle using a pseudorandom permutation like |
33 | // |
34 | // https://github.com/otherlab/geode/blob/master/geode/random/permute.cpp |
35 | // |
36 | // This is probably also the right thing if we want a GPU version of shuffling. |
37 | |
// Fisher-Yates shuffle of the random-access range [first, last).
//
// This replaces std::random_shuffle so that the number of draws is fixed:
// exactly (last - first) - 1 calls to `uniform`, and none for ranges of length
// 0 or 1. Callers rely on this to reserve a known number of random samples.
// `uniform(n)` must return an offset in [0, n). Each position is swapped with
// the element that many steps ahead of it, which may be itself.
template <class Iter, class Random>
static inline void ShuffleRange(Iter first, Iter last, Random& uniform) {
  using std::iter_swap;
  // `remaining` always equals last - first, i.e. the size of the unshuffled
  // tail. The final element needs no draw.
  for (auto remaining = last - first; remaining > 1; --remaining, ++first) {
    iter_swap(first, first + uniform(remaining));
  }
}
49 | |
// Writes a random row permutation of `input_mat` into `output_mat`.
//
// Steps:
//   1. Build the identity permutation of row indices [0, size), stored with
//      index type IntT. RandomShuffle picks int32 when every index fits,
//      which keeps this buffer small.
//   2. Shuffle it with ShuffleRange, which draws exactly size - 1 values from
//      `uniform` (for size >= 1).
//   3. Assign output row i from input row permutation[i].
//
// Both matrices are accessed through `.template chip<0>(row)`. Presumably
// they are Eigen 2-D tensor maps such as those returned by flat_outer_dims;
// TODO confirm. `output_mat` is taken by value, so it must be a view that
// writes through to the real storage.
template <class IntT, class InT, class OutT, class Random>
static void IndexedShuffle(const int64_t size, const InT& input_mat,
                           OutT output_mat, Random& uniform) {
  std::vector<IntT> permutation(size);
  std::iota(permutation.begin(), permutation.end(), IntT{0});
  ShuffleRange(permutation.begin(), permutation.end(), uniform);
  for (IntT i = 0; i < size; ++i) {
    output_mat.template chip<0>(i) = input_mat.template chip<0>(permutation[i]);
  }
}
62 | |
63 | template <typename T> |
64 | Status RandomShuffle(OpKernelContext* context, const Tensor& input, |
65 | int output_idx, |
66 | std::function<random::PhiloxRandom(int64_t)> get_rng) { |
67 | if (input.NumElements() <= 1 || input.dim_size(0) <= 1) { |
68 | // No shuffling is required, so copy input directly to output |
69 | context->set_output(output_idx, input); |
70 | } else { |
71 | // Reserve enough random samples for shuffling |
72 | const int64_t size = input.dim_size(0); |
73 | const int64_t samples = size - 1; |
74 | auto rng = get_rng(samples); |
75 | random::SingleSampleAdapter<random::PhiloxRandom> single(&rng); |
76 | const auto uniform = [&single](uint32 n) { return single() % n; }; |
77 | |
78 | if (input.dims() == 1) { |
79 | // For 1D data, copy and then shuffle in place |
80 | context->set_output(output_idx, tensor::DeepCopy(input)); |
81 | auto vec = context->mutable_output(output_idx)->vec<T>(); |
82 | ShuffleRange(vec.data(), vec.data() + size, uniform); |
83 | } else { |
84 | // For >= 2D, shuffle indices and then copy across |
85 | Tensor* output = nullptr; |
86 | TF_RETURN_IF_ERROR( |
87 | context->allocate_output(output_idx, input.shape(), &output)); |
88 | const auto input_mat = input.flat_outer_dims<T>(); |
89 | auto output_mat = output->flat_outer_dims<T>(); |
90 | if (size < kint32max) { |
91 | IndexedShuffle<int32>(size, input_mat, output_mat, uniform); |
92 | } else { |
93 | IndexedShuffle<int64_t>(size, input_mat, output_mat, uniform); |
94 | } |
95 | } |
96 | } |
97 | return OkStatus(); |
98 | } |
99 | |
100 | } // namespace tensorflow |
101 | |
102 | #endif // TENSORFLOW_CORE_KERNELS_SHUFFLE_COMMON_H_ |
103 | |