/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateless_random_ops_v2.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/random_poisson_op.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/kernels/stateless_random_ops_v2_util.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"

#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
  _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace {

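// Base kernel for the stateless random ops. Compute() reads the shape input
// (input 0) and the key/counter/alg inputs (inputs 1-3), allocates the output
// tensor, and delegates the actual sampling to the device- and
// distribution-specific Fill() implementation.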
class StatelessRandomOpBase : public OpKernel {
 public:
  explicit StatelessRandomOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES_VALUE(auto key_counter_alg, ctx,
                      GetKeyCounterAlgFromInputs(ctx, 1, 2, 3));
    auto key_t = std::get<0>(key_counter_alg);
    auto counter_t = std::get<1>(key_counter_alg);
    auto alg = std::get<2>(key_counter_alg);

    TensorShape shape;
    OP_REQUIRES_OK(ctx, tensor::MakeShape(ctx->input(0), &shape));

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) {
      return;
    }

    // Fill in the random numbers
    Fill(ctx, alg, key_t, counter_t, output);
  }

  // The part of Compute that depends on device, type, and distribution.
  // Must be a tail call because it doesn't report errors via return value.
  virtual void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
                    const Tensor& counter, Tensor* output) = 0;
};

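// Samples from `Distribution` (e.g. uniform, normal, or truncated normal)
// using the Philox counter-based generator keyed by the `key` and `counter`
// inputs.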
template <typename Device, typename Distribution>
class StatelessRandomOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
            const Tensor& counter, Tensor* output) override {
    typedef typename Distribution::ResultElementType T;
    auto flat = output->flat<T>();
    if (alg == RNG_ALG_PHILOX) {
      // Reuse the compute kernels from the stateful random ops
      auto key_data = key.flat<uint64>().data();
      auto counter_data = counter.flat<uint64>().data();
      functor::FillPhiloxRandom<Device, Distribution>()(
          ctx, ctx->eigen_device<Device>(), key_data, counter_data,
          random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(),
          Distribution());
    } else {
      OP_REQUIRES(ctx, false,
                  errors::InvalidArgument("Unsupported algorithm id: ", alg));
    }
  }
};

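// Samples integers uniformly from [minval, maxval). The scalar bounds are
// read from inputs 4 and 5 and validated before the distribution is built.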
template <typename Device, typename IntType>
class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
            const Tensor& counter, Tensor* output) override {
    const Tensor& minval = ctx->input(4);
    const Tensor& maxval = ctx->input(5);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Verify that minval < maxval. Note that we'll never reach this point for
    // empty output. Zero impossible things are fine.
    const auto lo = minval.scalar<IntType>()();
    const auto hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        ctx, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build distribution
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);

    auto flat = output->flat<IntType>();
    if (alg == RNG_ALG_PHILOX) {
      // Reuse the compute kernels from the stateful random ops
      auto key_data = key.flat<uint64>().data();
      auto counter_data = counter.flat<uint64>().data();
      functor::FillPhiloxRandom<Device, Distribution>()(
          ctx, ctx->eigen_device<Device>(), key_data, counter_data,
          random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist);
    } else {
      OP_REQUIRES(ctx, false,
                  errors::InvalidArgument("Unsupported algorithm id: ", alg));
    }
  }
};

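// Samples integers uniformly over the full range of `IntType`; no
// minval/maxval inputs are involved.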
template <typename Device, typename IntType>
class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
            const Tensor& counter, Tensor* output) override {
    // Build distribution
    typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist;

    auto flat = output->flat<IntType>();
    if (alg == RNG_ALG_PHILOX) {
      // Reuse the compute kernels from the stateful random ops
      auto key_data = key.flat<uint64>().data();
      auto counter_data = counter.flat<uint64>().data();
      functor::FillPhiloxRandom<Device, Distribution>()(
          ctx, ctx->eigen_device<Device>(), key_data, counter_data,
          random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist);
    } else {
      OP_REQUIRES(ctx, false,
                  errors::InvalidArgument("Unsupported algorithm id: ", alg));
    }
  }
};

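// Derives a (key, counter, algorithm) triple from a shape-[2] seed tensor.
// The algorithm output is always Philox here.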
class GetKeyCounterAlgOp : public OpKernel {
 public:
  explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& seed_t = ctx->input(0);
    OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));
    // Allocate outputs
    Tensor* key_output;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({RNG_KEY_SIZE}), &key_output));
    Tensor* counter_output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(1, TensorShape({RNG_MAX_COUNTER_SIZE}),
                                        &counter_output));
    Tensor* alg_output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({}), &alg_output));

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));
    WriteKeyToMem(key, key_output->flat<uint64>().data());
    WriteCounterToMem(counter, counter_output->flat<uint64>().data());
    alg_output->flat<int>()(0) = RNG_ALG_PHILOX;
  }
};

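// Same as GetKeyCounterAlgOp, but outputs only the key and counter; the
// algorithm is reported separately by GetAlgOp below.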
class GetKeyCounterOp : public OpKernel {
 public:
  explicit GetKeyCounterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& seed_t = ctx->input(0);
    OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));
    // Allocate outputs
    Tensor* key_output;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({RNG_KEY_SIZE}), &key_output));
    Tensor* counter_output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(1, TensorShape({RNG_MAX_COUNTER_SIZE}),
                                        &counter_output));

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));
    WriteKeyToMem(key, key_output->flat<uint64>().data());
    WriteCounterToMem(counter, counter_output->flat<uint64>().data());
  }
};

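// Outputs the RNG algorithm as a scalar; currently always Philox.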
class GetAlgOp : public OpKernel {
 public:
  explicit GetAlgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Tensor* alg_output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &alg_output));
    alg_output->flat<int>()(0) = RNG_ALG_PHILOX;
  }
};

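// Kernel registrations. The shape, alg, minval, and maxval inputs are placed
// in host memory so the kernels can read them directly on the host, even when
// registered for GPU.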
#define REGISTER(DEVICE, TYPE) \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomUniformV2") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("alg") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomNormalV2") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("alg") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessTruncatedNormalV2") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("alg") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomOp< \
          DEVICE##Device, \
          random::TruncatedNormalDistribution< \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)

#define REGISTER_FULL_INT(DEVICE, TYPE) \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomUniformFullIntV2") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("alg") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>)

#define REGISTER_INT(DEVICE, TYPE) \
  REGISTER_FULL_INT(DEVICE, TYPE); \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformIntV2") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("shape") \
                              .HostMemory("alg") \
                              .HostMemory("minval") \
                              .HostMemory("maxval") \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)

#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE)
#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE)

TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
TF_CALL_uint64(REGISTER_FULL_INT_CPU);

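// Registrations for the seed-to-key/counter/alg helper ops; their inputs and
// outputs all live in host memory.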
#define REGISTER_GET_KCA(DEVICE) \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounterAlg") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("seed") \
                              .HostMemory("key") \
                              .HostMemory("counter") \
                              .HostMemory("alg"), \
                          GetKeyCounterAlgOp) \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounter") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("seed") \
                              .HostMemory("key") \
                              .HostMemory("counter"), \
                          GetKeyCounterOp) \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomGetAlg").Device(DEVICE_##DEVICE).HostMemory("alg"), \
      GetAlgOp)

REGISTER_GET_KCA(CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_INT_GPU);
TF_CALL_int64(REGISTER_INT_GPU);
TF_CALL_uint32(REGISTER_FULL_INT_GPU);
TF_CALL_uint64(REGISTER_FULL_INT_GPU);

REGISTER_GET_KCA(GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER
#undef REGISTER_FULL_INT
#undef REGISTER_INT
#undef REGISTER_CPU
#undef REGISTER_GPU
#undef REGISTER_INT_CPU
#undef REGISTER_INT_GPU
#undef REGISTER_FULL_INT_CPU
#undef REGISTER_FULL_INT_GPU

#undef REGISTER_GET_KCA

}  // namespace

}  // namespace tensorflow