/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 | #include "tensorflow/core/framework/bounds_check.h" |
17 | #include "tensorflow/core/framework/op_kernel.h" |
18 | #include "tensorflow/core/framework/register_types.h" |
19 | #include "tensorflow/core/framework/tensor.h" |
20 | #include "tensorflow/core/framework/tensor_shape.h" |
21 | #include "tensorflow/core/framework/tensor_util.h" |
22 | #include "tensorflow/core/kernels/random_op.h" |
23 | #include "tensorflow/core/kernels/random_poisson_op.h" |
24 | #include "tensorflow/core/lib/random/random_distributions.h" |
25 | #include "tensorflow/core/platform/logging.h" |

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key,
                   random::PhiloxRandom::ResultType* out_counter) {
  // Grab the two seeds
  uint64 seed0;
  uint64 seed1;
  if (seed.dtype() == DT_INT32) {
    const auto seed_vals = seed.flat<int32>();
    seed0 = internal::SubtleMustCopy(seed_vals(0));
    seed1 = internal::SubtleMustCopy(seed_vals(1));
  } else if (seed.dtype() == DT_INT64) {
    const auto seed_vals = seed.flat<int64_t>();
    seed0 = internal::SubtleMustCopy(seed_vals(0));
    seed1 = internal::SubtleMustCopy(seed_vals(1));
  } else {
    return errors::InvalidArgument("Invalid seed type: ",
                                   DataTypeString(seed.dtype()));
  }

  // Scramble the seeds so that the user doesn't need to worry about which
  // part of the seed needs to be strong.
  (*out_key)[0] = 0x3ec8f720;
  (*out_key)[1] = 0x02461e29;
  (*out_counter)[0] = static_cast<uint32>(seed0);
  (*out_counter)[1] = static_cast<uint32>(seed0 >> 32);
  (*out_counter)[2] = static_cast<uint32>(seed1);
  (*out_counter)[3] = static_cast<uint32>(seed1 >> 32);
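  // Run Philox once with the raw seed loaded into the counter and a fixed
  // key; the resulting block supplies the real key and the upper counter
  // words (the lower counter words are zeroed and left for sample indexing),
  // so every bit of the user seed influences all generated values.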
  const auto mix = random::PhiloxRandom(*out_counter, *out_key)();
  (*out_key)[0] = mix[0];
  (*out_key)[1] = mix[1];
  (*out_counter)[0] = (*out_counter)[1] = 0;
  (*out_counter)[2] = mix[2];
  (*out_counter)[3] = mix[3];
  return OkStatus();
}

namespace {

class StatelessRandomOpBase : public OpKernel {
 public:
  explicit StatelessRandomOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Sanitize input
    const Tensor& shape_t = context->input(0);
    const Tensor& seed_t = context->input(1);
    TensorShape shape;
    OP_REQUIRES_OK(context, tensor::MakeShape(shape_t, &shape));
    OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) return;

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));

    // Fill in the random numbers
    Fill(context, random::PhiloxRandom(counter, key), output);
  }

  // The part of Compute that depends on device, type, and distribution
  virtual void Fill(OpKernelContext* context, random::PhiloxRandom random,
                    Tensor* output) = 0;
};

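// Fills the output with samples from a compile-time Distribution (uniform,
// normal, or truncated normal; see the REGISTER macro below), reusing the
// FillPhiloxRandom functor shared with the stateful random ops.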
template <typename Device, class Distribution>
class StatelessRandomOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    typedef typename Distribution::ResultElementType T;
    auto flat = output->flat<T>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), /*key=*/nullptr,
        /*counter=*/nullptr, random, flat.data(), flat.size(), Distribution());
  }
};

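// Samples integers uniformly from the half-open interval [minval, maxval);
// minval and maxval are passed as scalar tensors in host memory.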
template <typename Device, typename IntType>
class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    const Tensor& minval = context->input(2);
    const Tensor& maxval = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Verify that minval < maxval. Note that we'll never reach this point for
    // empty output. Zero impossible things are fine.
    const auto lo = minval.scalar<IntType>()();
    const auto hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        context, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build distribution
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);

    auto flat = output->flat<IntType>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), /*key=*/nullptr,
        /*counter=*/nullptr, random, flat.data(), flat.size(), dist);
  }
};

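// Samples integers uniformly over the entire range of IntType (no
// minval/maxval inputs), e.g. every representable uint32 or uint64 value.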
template <typename Device, typename IntType>
class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    // Build distribution
    typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist;

    auto flat = output->flat<IntType>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), /*key=*/nullptr,
        /*counter=*/nullptr, random, flat.data(), flat.size(), dist);
  }
};

// Samples from one or more Poisson distributions.
template <typename T, typename U>
class StatelessRandomPoissonOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
            Tensor* output) override {
    const Tensor& rate_t = ctx->input(2);

    TensorShape samples_shape = output->shape();
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, rate_t.shape()),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));

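    // samples_shape ends with rate_t.shape() (checked above), so each of the
    // num_rate rate values receives samples_per_rate independent draws.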
    const int64_t num_rate = rate_t.NumElements();
    const int64_t samples_per_rate = samples_shape.num_elements() / num_rate;
    const auto rate_flat = rate_t.flat<T>().data();
    auto samples_flat = output->flat<U>().data();

    functor::PoissonFunctor<CPUDevice, T, U>()(
        ctx, ctx->eigen_device<CPUDevice>(), rate_flat, num_rate,
        samples_per_rate, random, samples_flat);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomPoissonOp);
};

#define REGISTER(DEVICE, TYPE) \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomUniform") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("seed") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomNormal") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("seed") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessTruncatedNormal") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("seed") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomOp< \
          DEVICE##Device, \
          random::TruncatedNormalDistribution< \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
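
// For reference, REGISTER(CPU, float) expands (in part) to:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("StatelessRandomUniform")
//           .Device(DEVICE_CPU)
//           .HostMemory("shape")
//           .HostMemory("seed")
//           .TypeConstraint<float>("dtype"),
//       StatelessRandomOp<CPUDevice, random::UniformDistribution<
//                                        random::PhiloxRandom, float> >);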

#define REGISTER_FULL_INT(DEVICE, TYPE) \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomUniformFullInt") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("seed") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>)

#define REGISTER_INT(DEVICE, TYPE) \
  REGISTER_FULL_INT(DEVICE, TYPE); \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("shape") \
                              .HostMemory("seed") \
                              .HostMemory("minval") \
                              .HostMemory("maxval") \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)

#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE)
#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE)

TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
TF_CALL_uint64(REGISTER_FULL_INT_CPU);

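// StatelessRandomPoisson is registered for the CPU only; its sampling is done
// by functor::PoissonFunctor<CPUDevice, ...> on the host.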
#define REGISTER_POISSON(RATE_TYPE, OUT_TYPE) \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomPoisson") \
                              .Device(DEVICE_CPU) \
                              .HostMemory("shape") \
                              .HostMemory("seed") \
                              .HostMemory("lam") \
                              .TypeConstraint<RATE_TYPE>("Rtype") \
                              .TypeConstraint<OUT_TYPE>("dtype"), \
                          StatelessRandomPoissonOp<RATE_TYPE, OUT_TYPE>)

#define REGISTER_ALL_POISSON(RATE_TYPE) \
  REGISTER_POISSON(RATE_TYPE, Eigen::half); \
  REGISTER_POISSON(RATE_TYPE, float); \
  REGISTER_POISSON(RATE_TYPE, double); \
  REGISTER_POISSON(RATE_TYPE, int32); \
  REGISTER_POISSON(RATE_TYPE, int64_t)

TF_CALL_half(REGISTER_ALL_POISSON);
TF_CALL_float(REGISTER_ALL_POISSON);
TF_CALL_double(REGISTER_ALL_POISSON);
TF_CALL_int32(REGISTER_ALL_POISSON);
TF_CALL_int64(REGISTER_ALL_POISSON);

#undef REGISTER_ALL_POISSON
#undef REGISTER_POISSON

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_INT_GPU);
TF_CALL_int64(REGISTER_INT_GPU);
TF_CALL_uint32(REGISTER_FULL_INT_GPU);
TF_CALL_uint64(REGISTER_FULL_INT_GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_CPU
#undef REGISTER_GPU
#undef REGISTER_INT_CPU
#undef REGISTER_INT_GPU
#undef REGISTER_FULL_INT_CPU
#undef REGISTER_FULL_INT_GPU

}  // namespace

}  // namespace tensorflow
318 | |