1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ |
18 | |
19 | #define EIGEN_USE_THREADS |
20 | |
21 | #include <algorithm> |
22 | #include <cmath> |
23 | #include <memory> |
24 | |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/register_types.h" |
27 | #include "tensorflow/core/framework/tensor.h" |
28 | #include "tensorflow/core/framework/tensor_shape.h" |
29 | #include "tensorflow/core/kernels/random_op.h" |
30 | #include "tensorflow/core/kernels/random_ops_util.h" |
31 | #include "tensorflow/core/lib/hash/crc32c.h" |
32 | #include "tensorflow/core/lib/random/random_distributions.h" |
33 | #include "tensorflow/core/lib/random/simple_philox.h" |
34 | #include "tensorflow/core/platform/logging.h" |
35 | #include "tensorflow/core/util/guarded_philox_random.h" |
36 | #include "tensorflow/core/util/work_sharder.h" |
37 | |
38 | #if EIGEN_COMP_GNUC && __cplusplus > 199711L |
39 | #define DISABLE_FLOAT_EQUALITY_WARNING \ |
40 | _Pragma("GCC diagnostic push") \ |
41 | _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") |
42 | #define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") |
43 | #else |
44 | #define DISABLE_FLOAT_EQUALITY_WARNING |
45 | #define ENABLE_FLOAT_EQUALITY_WARNING |
46 | #endif |
47 | |
48 | namespace tensorflow { |
49 | |
50 | typedef Eigen::ThreadPoolDevice CPUDevice; |
51 | typedef Eigen::GpuDevice GPUDevice; |
52 | |
53 | namespace functor { |
54 | using random::PhiloxRandom; |
55 | using random::SingleSampleAdapter; |
56 | |
// The default implementation of the functor, which should never be invoked.
// But we still need to provide an implementation for now for the linker to
// work, since we do not support all the distributions yet.
template <typename Device, class Distribution>
struct FillPhiloxRandom {
  typedef typename Distribution::ResultElementType T;
  // Fallback: immediately fails the kernel via OP_REQUIRES with an internal
  // error. No output is ever written; a device/distribution-specific
  // specialization is expected to shadow this.
  void operator()(OpKernelContext* ctx, const Device&, const uint64* key,
                  const uint64* counter, random::PhiloxRandom gen, T* data,
                  int64_t size, Distribution dist) {
    OP_REQUIRES(
        ctx, false,
        errors::Internal(
            "Default `FillPhiloxRandom` implementation should not be executed. "
            "The cause of this error is probably that `FillPhiloxRandom` does "
            "not support this device or random distribution yet." ));
  }
};
74 | |
// A class to fill a specified range of random groups.
// The bool parameter selects between the two specializations below:
// `false` for distributions that consume a fixed number of generator samples
// per output, `true` for distributions that consume a variable number.
template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomTask;
78 | |
79 | // Specialization for distribution that takes a fixed number of samples for |
80 | // each output. |
81 | template <class Distribution> |
82 | struct FillPhiloxRandomTask<Distribution, false> { |
83 | typedef typename Distribution::ResultElementType T; |
84 | static void Run(random::PhiloxRandom gen, T* data, int64_t size, |
85 | int64_t start_group, int64_t limit_group, Distribution dist) { |
86 | const int kGroupSize = Distribution::kResultElementCount; |
87 | |
88 | gen.Skip(start_group); |
89 | int64_t offset = start_group * kGroupSize; |
90 | |
91 | // First fill all the full-size groups |
92 | int64_t limit_group_full = std::min(limit_group, size / kGroupSize); |
93 | for (int64_t index = start_group; index < limit_group_full; ++index) { |
94 | auto samples = dist(&gen); |
95 | std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); |
96 | offset += kGroupSize; |
97 | } |
98 | |
99 | // If there are any remaining elements that need to be filled, process them |
100 | if (limit_group_full < limit_group) { |
101 | int64_t remaining_size = size - limit_group_full * kGroupSize; |
102 | auto samples = dist(&gen); |
103 | std::copy(&samples[0], &samples[0] + remaining_size, data + offset); |
104 | } |
105 | } |
106 | }; |
107 | |
108 | // Specialization for distribution that takes a variable number of samples for |
109 | // each output. This will be slower due to the generality. |
110 | template <class Distribution> |
111 | struct FillPhiloxRandomTask<Distribution, true> { |
112 | typedef typename Distribution::ResultElementType T; |
113 | static constexpr int64_t kReservedSamplesPerOutput = 256; |
114 | |
115 | static void Run(random::PhiloxRandom base_gen, T* data, int64_t size, |
116 | int64_t start_group, int64_t limit_group, Distribution dist) { |
117 | const int kGroupSize = Distribution::kResultElementCount; |
118 | |
119 | static const int kGeneratorSkipPerOutputGroup = |
120 | kGroupSize * kReservedSamplesPerOutput / |
121 | PhiloxRandom::kResultElementCount; |
122 | |
123 | int64_t offset = start_group * kGroupSize; |
124 | |
125 | // First fill all the full-size groups |
126 | int64_t limit_group_full = std::min(limit_group, size / kGroupSize); |
127 | int64_t group_index; |
128 | for (group_index = start_group; group_index < limit_group_full; |
129 | ++group_index) { |
130 | // Reset the generator to the beginning of the output group region |
131 | // This is necessary if we want the results to be independent of order |
132 | // of work |
133 | PhiloxRandom gen = base_gen; |
134 | gen.Skip(group_index * kGeneratorSkipPerOutputGroup); |
135 | SingleSampleAdapter<PhiloxRandom> single_samples(&gen); |
136 | |
137 | auto samples = dist(&single_samples); |
138 | std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); |
139 | offset += kGroupSize; |
140 | } |
141 | |
142 | // If there are any remaining elements that need to be filled, process them |
143 | if (limit_group_full < limit_group) { |
144 | PhiloxRandom gen = base_gen; |
145 | gen.Skip(group_index * kGeneratorSkipPerOutputGroup); |
146 | SingleSampleAdapter<PhiloxRandom> single_samples(&gen); |
147 | |
148 | int64_t remaining_size = size - limit_group_full * kGroupSize; |
149 | auto samples = dist(&single_samples); |
150 | std::copy(&samples[0], &samples[0] + remaining_size, data + offset); |
151 | } |
152 | } |
153 | }; |
154 | |
155 | // Partial specialization for CPU to fill the entire region with randoms |
156 | // It splits the work into several tasks and run them in parallel |
157 | template <class Distribution> |
158 | void FillPhiloxRandom<CPUDevice, Distribution>::operator()( |
159 | OpKernelContext* ctx, const CPUDevice&, const uint64* key, |
160 | const uint64* counter, random::PhiloxRandom gen, |
161 | typename Distribution::ResultElementType* data, int64_t size, |
162 | Distribution dist) { |
163 | const int kGroupSize = Distribution::kResultElementCount; |
164 | |
165 | auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); |
166 | |
167 | int64_t total_group_count = (size + kGroupSize - 1) / kGroupSize; |
168 | |
169 | const int kGroupCost = |
170 | random::PhiloxRandom::kResultElementCount * |
171 | (random::PhiloxRandom::kElementCost + Distribution::kElementCost); |
172 | |
173 | if (key != nullptr && counter != nullptr) { |
174 | gen = GetPhiloxRandomFromCounterKeyMem(counter, key); |
175 | } |
176 | |
177 | Shard(worker_threads.num_threads, worker_threads.workers, total_group_count, |
178 | kGroupCost, |
179 | [&gen, data, size, dist](int64_t start_group, int64_t limit_group) { |
180 | FillPhiloxRandomTask< |
181 | Distribution, |
182 | Distribution::kVariableSamplesPerOutput>::Run(gen, data, size, |
183 | start_group, |
184 | limit_group, dist); |
185 | }); |
186 | } |
187 | |
188 | } // namespace functor |
189 | |
190 | |
191 | } // end namespace tensorflow |
192 | |
193 | #endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ |
194 | |