/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_poisson_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

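// Converts a shape-[2] seed tensor (int32 or int64) into a scrambled Philox
// key/counter pair. Returns InvalidArgument for any other seed dtype.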
Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key,
                   random::PhiloxRandom::ResultType* out_counter) {
  // Grab the two seeds.
  uint64 seed0;
  uint64 seed1;
  if (seed.dtype() == DT_INT32) {
    const auto seed_vals = seed.flat<int32>();
    seed0 = internal::SubtleMustCopy(seed_vals(0));
    seed1 = internal::SubtleMustCopy(seed_vals(1));
  } else if (seed.dtype() == DT_INT64) {
    const auto seed_vals = seed.flat<int64_t>();
    seed0 = internal::SubtleMustCopy(seed_vals(0));
    seed1 = internal::SubtleMustCopy(seed_vals(1));
  } else {
    return errors::InvalidArgument("Invalid seed type: ",
                                   DataTypeString(seed.dtype()));
  }

  // Scramble the seeds so that the user doesn't need to worry about which
  // part of the seed needs to be strong.
  (*out_key)[0] = 0x3ec8f720;
  (*out_key)[1] = 0x02461e29;
  (*out_counter)[0] = static_cast<uint32>(seed0);
  (*out_counter)[1] = static_cast<uint32>(seed0 >> 32);
  (*out_counter)[2] = static_cast<uint32>(seed1);
  (*out_counter)[3] = static_cast<uint32>(seed1 >> 32);
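  // Mix with one full Philox invocation so that the final key and counter
  // both depend on all bits of the user-supplied seed.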
  const auto mix = random::PhiloxRandom(*out_counter, *out_key)();
  (*out_key)[0] = mix[0];
  (*out_key)[1] = mix[1];
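  // Zero the low 64 bits of the counter: Philox increments the counter from
  // its low words, so this leaves a full 2^64 sample space before the
  // counter rolls into the mixed high words.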
  (*out_counter)[0] = (*out_counter)[1] = 0;
  (*out_counter)[2] = mix[2];
  (*out_counter)[3] = mix[3];
  return OkStatus();
}

namespace {

class StatelessRandomOpBase : public OpKernel {
 public:
  explicit StatelessRandomOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Sanitize input.
    const Tensor& shape_t = context->input(0);
    const Tensor& seed_t = context->input(1);
    TensorShape shape;
    OP_REQUIRES_OK(context, tensor::MakeShape(shape_t, &shape));
    OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));

    // Allocate output.
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) return;

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));

    // Fill in the random numbers.
    Fill(context, random::PhiloxRandom(counter, key), output);
  }

  // The part of Compute that depends on device, type, and distribution.
  virtual void Fill(OpKernelContext* context, random::PhiloxRandom random,
                    Tensor* output) = 0;
};

template <typename Device, class Distribution>
class StatelessRandomOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    typedef typename Distribution::ResultElementType T;
    auto flat = output->flat<T>();
    // Reuse the compute kernels from the stateful random ops.
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), /*key=*/nullptr,
        /*counter=*/nullptr, random, flat.data(), flat.size(), Distribution());
  }
};
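
// FillPhiloxRandom partitions the output and skips the Philox counter ahead
// per partition, so the result is deterministic regardless of how the work
// is sharded across threads or GPU blocks. In the Python API, for example,
// tf.random.stateless_uniform([3], seed=[1, 2]) is backed by a float kernel
// registered below.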

template <typename Device, typename IntType>
class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    const Tensor& minval = context->input(2);
    const Tensor& maxval = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Verify that minval < maxval. We never reach this point for empty
    // output (Compute returns early), and drawing zero samples from an
    // impossible range is vacuously fine.
    const auto lo = minval.scalar<IntType>()();
    const auto hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        context, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build the distribution.
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);
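    // dist yields integers uniformly in the half-open interval [lo, hi).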

    auto flat = output->flat<IntType>();
    // Reuse the compute kernels from the stateful random ops.
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), /*key=*/nullptr,
        /*counter=*/nullptr, random, flat.data(), flat.size(), dist);
  }
};

template <typename Device, typename IntType>
class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    // Build the distribution.
    typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist;
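    // UniformFullIntDistribution draws uniformly over the entire value range
    // of IntType, which is why this op takes no minval/maxval inputs.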

    auto flat = output->flat<IntType>();
    // Reuse the compute kernels from the stateful random ops.
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), /*key=*/nullptr,
        /*counter=*/nullptr, random, flat.data(), flat.size(), dist);
  }
};

// Samples from one or more Poisson distributions.
template <typename T, typename U>
class StatelessRandomPoissonOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
            Tensor* output) override {
    const Tensor& rate_t = ctx->input(2);

    TensorShape samples_shape = output->shape();
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, rate_t.shape()),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));

    const int64_t num_rate = rate_t.NumElements();
    const int64_t samples_per_rate = samples_shape.num_elements() / num_rate;
    const auto rate_flat = rate_t.flat<T>().data();
    auto samples_flat = output->flat<U>().data();

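    // Note: Poisson sampling is implemented on the CPU only, so this class
    // hardcodes CPUDevice and is registered solely for DEVICE_CPU below.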
    functor::PoissonFunctor<CPUDevice, T, U>()(
        ctx, ctx->eigen_device<CPUDevice>(), rate_flat, num_rate,
        samples_per_rate, random, samples_flat);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomPoissonOp);
};

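// REGISTER(DEVICE, TYPE) instantiates the uniform, normal, and truncated
// normal stateless kernels for one device/type pair; the shape and seed
// inputs are pinned to host memory so Compute can read them directly.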
#define REGISTER(DEVICE, TYPE)                                              \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessRandomUniform")                                        \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<DEVICE##Device, random::UniformDistribution<        \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessRandomNormal")                                         \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<DEVICE##Device, random::NormalDistribution<         \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessTruncatedNormal")                                      \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<                                                    \
          DEVICE##Device,                                                   \
          random::TruncatedNormalDistribution<                              \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)

#define REGISTER_FULL_INT(DEVICE, TYPE)     \
  REGISTER_KERNEL_BUILDER(                  \
      Name("StatelessRandomUniformFullInt") \
          .Device(DEVICE_##DEVICE)          \
          .HostMemory("shape")              \
          .HostMemory("seed")               \
          .TypeConstraint<TYPE>("dtype"),   \
      StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>)

#define REGISTER_INT(DEVICE, TYPE)                            \
  REGISTER_FULL_INT(DEVICE, TYPE);                            \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt")   \
                              .Device(DEVICE_##DEVICE)        \
                              .HostMemory("shape")            \
                              .HostMemory("seed")             \
                              .HostMemory("minval")           \
                              .HostMemory("maxval")           \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)

#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE)
#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE)

TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
TF_CALL_uint64(REGISTER_FULL_INT_CPU);

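// StatelessRandomPoisson is CPU-only and is registered for every
// (rate type, output type) pair produced by the macros below. The rate
// input is named "lam", after the conventional Poisson lambda parameter.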
#define REGISTER_POISSON(RATE_TYPE, OUT_TYPE)                     \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomPoisson")          \
                              .Device(DEVICE_CPU)                 \
                              .HostMemory("shape")                \
                              .HostMemory("seed")                 \
                              .HostMemory("lam")                  \
                              .TypeConstraint<RATE_TYPE>("Rtype") \
                              .TypeConstraint<OUT_TYPE>("dtype"), \
                          StatelessRandomPoissonOp<RATE_TYPE, OUT_TYPE>)

#define REGISTER_ALL_POISSON(RATE_TYPE)     \
  REGISTER_POISSON(RATE_TYPE, Eigen::half); \
  REGISTER_POISSON(RATE_TYPE, float);       \
  REGISTER_POISSON(RATE_TYPE, double);      \
  REGISTER_POISSON(RATE_TYPE, int32);       \
  REGISTER_POISSON(RATE_TYPE, int64_t)

TF_CALL_half(REGISTER_ALL_POISSON);
TF_CALL_float(REGISTER_ALL_POISSON);
TF_CALL_double(REGISTER_ALL_POISSON);
TF_CALL_int32(REGISTER_ALL_POISSON);
TF_CALL_int64(REGISTER_ALL_POISSON);

#undef REGISTER_ALL_POISSON
#undef REGISTER_POISSON

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_INT_GPU);
TF_CALL_int64(REGISTER_INT_GPU);
TF_CALL_uint32(REGISTER_FULL_INT_GPU);
TF_CALL_uint64(REGISTER_FULL_INT_GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_CPU
#undef REGISTER_GPU
#undef REGISTER_INT_CPU
#undef REGISTER_INT_GPU
#undef REGISTER_FULL_INT_CPU
#undef REGISTER_FULL_INT_GPU

}  // namespace

}  // namespace tensorflow