1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/stateless_random_ops_v2.h" |
17 | |
18 | #include "tensorflow/core/framework/bounds_check.h" |
19 | #include "tensorflow/core/framework/op_kernel.h" |
20 | #include "tensorflow/core/framework/register_types.h" |
21 | #include "tensorflow/core/framework/rng_alg.h" |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/framework/tensor_shape.h" |
24 | #include "tensorflow/core/framework/tensor_util.h" |
25 | #include "tensorflow/core/kernels/random_op.h" |
26 | #include "tensorflow/core/kernels/random_ops_util.h" |
27 | #include "tensorflow/core/kernels/random_poisson_op.h" |
28 | #include "tensorflow/core/kernels/stateless_random_ops.h" |
29 | #include "tensorflow/core/kernels/stateless_random_ops_v2_util.h" |
30 | #include "tensorflow/core/lib/random/random_distributions.h" |
31 | #include "tensorflow/core/platform/logging.h" |
32 | #include "tensorflow/core/util/work_sharder.h" |
33 | |
// Helper macros for locally silencing GCC's -Wfloat-equal diagnostic.
// DISABLE_FLOAT_EQUALITY_WARNING pushes the diagnostic state and ignores
// -Wfloat-equal; ENABLE_FLOAT_EQUALITY_WARNING pops that state again.
// Both expand to nothing unless EIGEN_COMP_GNUC is set and __cplusplus is
// newer than the C++98 value (199711L).
// NOTE(review): neither macro is referenced in the part of this file visible
// here - confirm whether they are still needed.
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push") \
  _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
43 | |
44 | namespace tensorflow { |
45 | |
// Short device aliases. The kernel registration macros in this file build
// these names by token pasting (`DEVICE##Device`), so the spellings
// `CPUDevice` and `GPUDevice` must be kept.
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
48 | |
49 | namespace { |
50 | |
51 | class StatelessRandomOpBase : public OpKernel { |
52 | public: |
53 | explicit StatelessRandomOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
54 | |
55 | void Compute(OpKernelContext* ctx) override { |
56 | OP_REQUIRES_VALUE(auto key_counter_alg, ctx, |
57 | GetKeyCounterAlgFromInputs(ctx, 1, 2, 3)); |
58 | auto key_t = std::get<0>(key_counter_alg); |
59 | auto counter_t = std::get<1>(key_counter_alg); |
60 | auto alg = std::get<2>(key_counter_alg); |
61 | |
62 | TensorShape shape; |
63 | OP_REQUIRES_OK(ctx, tensor::MakeShape(ctx->input(0), &shape)); |
64 | |
65 | // Allocate output |
66 | Tensor* output; |
67 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output)); |
68 | if (shape.num_elements() == 0) { |
69 | return; |
70 | } |
71 | |
72 | // Fill in the random numbers |
73 | Fill(ctx, alg, key_t, counter_t, output); |
74 | } |
75 | |
76 | // The part of Compute that depends on device, type, and distribution. |
77 | // Must be a tail call because it doesn't report error via return value. |
78 | virtual void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key, |
79 | const Tensor& counter, Tensor* output) = 0; |
80 | }; |
81 | |
82 | template <typename Device, typename Distribution> |
83 | class StatelessRandomOp : public StatelessRandomOpBase { |
84 | public: |
85 | using StatelessRandomOpBase::StatelessRandomOpBase; |
86 | |
87 | void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key, |
88 | const Tensor& counter, Tensor* output) override { |
89 | typedef typename Distribution::ResultElementType T; |
90 | auto flat = output->flat<T>(); |
91 | if (alg == RNG_ALG_PHILOX) { |
92 | // Reuse the compute kernels from the stateful random ops |
93 | auto key_data = key.flat<uint64>().data(); |
94 | auto counter_data = counter.flat<uint64>().data(); |
95 | functor::FillPhiloxRandom<Device, Distribution>()( |
96 | ctx, ctx->eigen_device<Device>(), key_data, counter_data, |
97 | random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), |
98 | Distribution()); |
99 | } else { |
100 | OP_REQUIRES(ctx, false, |
101 | errors::InvalidArgument("Unsupported algorithm id: " , alg)); |
102 | } |
103 | } |
104 | }; |
105 | |
106 | template <typename Device, typename IntType> |
107 | class StatelessRandomUniformIntOp : public StatelessRandomOpBase { |
108 | public: |
109 | using StatelessRandomOpBase::StatelessRandomOpBase; |
110 | |
111 | void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key, |
112 | const Tensor& counter, Tensor* output) override { |
113 | const Tensor& minval = ctx->input(4); |
114 | const Tensor& maxval = ctx->input(5); |
115 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()), |
116 | errors::InvalidArgument("minval must be 0-D, got shape " , |
117 | minval.shape().DebugString())); |
118 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()), |
119 | errors::InvalidArgument("maxval must be 0-D, got shape " , |
120 | maxval.shape().DebugString())); |
121 | |
122 | // Verify that minval < maxval. Note that we'll never reach this point for |
123 | // empty output. Zero impossible things are fine. |
124 | const auto lo = minval.scalar<IntType>()(); |
125 | const auto hi = maxval.scalar<IntType>()(); |
126 | OP_REQUIRES( |
127 | ctx, lo < hi, |
128 | errors::InvalidArgument("Need minval < maxval, got " , lo, " >= " , hi)); |
129 | |
130 | // Build distribution |
131 | typedef random::UniformDistribution<random::PhiloxRandom, IntType> |
132 | Distribution; |
133 | Distribution dist(lo, hi); |
134 | |
135 | auto flat = output->flat<IntType>(); |
136 | if (alg == RNG_ALG_PHILOX) { |
137 | // Reuse the compute kernels from the stateful random ops |
138 | auto key_data = key.flat<uint64>().data(); |
139 | auto counter_data = counter.flat<uint64>().data(); |
140 | functor::FillPhiloxRandom<Device, Distribution>()( |
141 | ctx, ctx->eigen_device<Device>(), key_data, counter_data, |
142 | random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist); |
143 | } else { |
144 | OP_REQUIRES(ctx, false, |
145 | errors::InvalidArgument("Unsupported algorithm id: " , alg)); |
146 | } |
147 | } |
148 | }; |
149 | |
150 | template <typename Device, typename IntType> |
151 | class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase { |
152 | public: |
153 | using StatelessRandomOpBase::StatelessRandomOpBase; |
154 | |
155 | void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key, |
156 | const Tensor& counter, Tensor* output) override { |
157 | // Build distribution |
158 | typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType> |
159 | Distribution; |
160 | Distribution dist; |
161 | |
162 | auto flat = output->flat<IntType>(); |
163 | if (alg == RNG_ALG_PHILOX) { |
164 | // Reuse the compute kernels from the stateful random ops |
165 | auto key_data = key.flat<uint64>().data(); |
166 | auto counter_data = counter.flat<uint64>().data(); |
167 | functor::FillPhiloxRandom<Device, Distribution>()( |
168 | ctx, ctx->eigen_device<Device>(), key_data, counter_data, |
169 | random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist); |
170 | } else { |
171 | OP_REQUIRES(ctx, false, |
172 | errors::InvalidArgument("Unsupported algorithm id: " , alg)); |
173 | } |
174 | } |
175 | }; |
176 | |
177 | class GetKeyCounterAlgOp : public OpKernel { |
178 | public: |
179 | explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
180 | |
181 | void Compute(OpKernelContext* ctx) override { |
182 | const Tensor& seed_t = ctx->input(0); |
183 | OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2, |
184 | errors::InvalidArgument("seed must have shape [2], not " , |
185 | seed_t.shape().DebugString())); |
186 | // Allocate outputs |
187 | Tensor* key_output; |
188 | OP_REQUIRES_OK( |
189 | ctx, ctx->allocate_output(0, TensorShape({RNG_KEY_SIZE}), &key_output)); |
190 | Tensor* counter_output; |
191 | OP_REQUIRES_OK(ctx, |
192 | ctx->allocate_output(1, TensorShape({RNG_MAX_COUNTER_SIZE}), |
193 | &counter_output)); |
194 | Tensor* alg_output; |
195 | OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({}), &alg_output)); |
196 | |
197 | random::PhiloxRandom::Key key; |
198 | random::PhiloxRandom::ResultType counter; |
199 | OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter)); |
200 | WriteKeyToMem(key, key_output->flat<uint64>().data()); |
201 | WriteCounterToMem(counter, counter_output->flat<uint64>().data()); |
202 | alg_output->flat<int>()(0) = RNG_ALG_PHILOX; |
203 | } |
204 | }; |
205 | |
206 | class GetKeyCounterOp : public OpKernel { |
207 | public: |
208 | explicit GetKeyCounterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
209 | |
210 | void Compute(OpKernelContext* ctx) override { |
211 | const Tensor& seed_t = ctx->input(0); |
212 | OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2, |
213 | errors::InvalidArgument("seed must have shape [2], not " , |
214 | seed_t.shape().DebugString())); |
215 | // Allocate outputs |
216 | Tensor* key_output; |
217 | OP_REQUIRES_OK( |
218 | ctx, ctx->allocate_output(0, TensorShape({RNG_KEY_SIZE}), &key_output)); |
219 | Tensor* counter_output; |
220 | OP_REQUIRES_OK(ctx, |
221 | ctx->allocate_output(1, TensorShape({RNG_MAX_COUNTER_SIZE}), |
222 | &counter_output)); |
223 | |
224 | random::PhiloxRandom::Key key; |
225 | random::PhiloxRandom::ResultType counter; |
226 | OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter)); |
227 | WriteKeyToMem(key, key_output->flat<uint64>().data()); |
228 | WriteCounterToMem(counter, counter_output->flat<uint64>().data()); |
229 | } |
230 | }; |
231 | |
232 | class GetAlgOp : public OpKernel { |
233 | public: |
234 | explicit GetAlgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
235 | |
236 | void Compute(OpKernelContext* ctx) override { |
237 | Tensor* alg_output; |
238 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &alg_output)); |
239 | alg_output->flat<int>()(0) = RNG_ALG_PHILOX; |
240 | } |
241 | }; |
242 | |
// Registers the three kernels implemented by StatelessRandomOp for one
// (DEVICE, TYPE) pair: StatelessRandomUniformV2, StatelessRandomNormalV2 and
// StatelessTruncatedNormalV2.
//   DEVICE: bare token (e.g. CPU or GPU); it is pasted into DEVICE_##DEVICE
//           and DEVICE##Device.
//   TYPE:   element type, used for the "dtype" constraint and as the
//           distribution's value type.
// The "shape" and "alg" arguments are declared as host memory. The truncated
// normal kernel wraps PhiloxRandom in SingleSampleAdapter; the other two use
// PhiloxRandom directly. The expansion has no trailing semicolon.
#define REGISTER(DEVICE, TYPE) \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomUniformV2") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("alg") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomNormalV2") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("alg") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessTruncatedNormalV2") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("alg") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomOp< \
          DEVICE##Device, \
          random::TruncatedNormalDistribution< \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
270 | |
// Registers StatelessRandomUniformFullIntV2 for one (DEVICE, TYPE) pair,
// implemented by StatelessRandomUniformFullIntOp. The "shape" and "alg"
// arguments are declared as host memory. The expansion has no trailing
// semicolon.
#define REGISTER_FULL_INT(DEVICE, TYPE) \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomUniformFullIntV2") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("shape") \
          .HostMemory("alg") \
          .TypeConstraint<TYPE>("dtype"), \
      StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>)
279 | |
// Registers both integer kernels for one (DEVICE, TYPE) pair: the full-range
// kernel (via REGISTER_FULL_INT) and StatelessRandomUniformIntV2, implemented
// by StatelessRandomUniformIntOp. For the latter, "minval" and "maxval" are
// declared as host memory in addition to "shape" and "alg". The expansion has
// no trailing semicolon.
#define REGISTER_INT(DEVICE, TYPE) \
  REGISTER_FULL_INT(DEVICE, TYPE); \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformIntV2") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("shape") \
                              .HostMemory("alg") \
                              .HostMemory("minval") \
                              .HostMemory("maxval") \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)
290 | |
// Single-argument adapters that bind the device, so the TF_CALL_<type>
// macros used in this file can apply a registration macro to one type.
#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE)
#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE)
297 | |
// CPU registrations.
//  - half, bfloat16, float, double: uniform, normal and truncated-normal
//    kernels (REGISTER).
//  - int32, int64: ranged and full-range integer kernels (REGISTER_INT).
//  - uint32, uint64: full-range integer kernel only (REGISTER_FULL_INT).
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
TF_CALL_uint64(REGISTER_FULL_INT_CPU);
306 | |
// Registers the three seed/algorithm helper kernels for DEVICE (a bare token
// such as CPU or GPU, pasted into DEVICE_##DEVICE):
//   StatelessRandomGetKeyCounterAlg -> GetKeyCounterAlgOp
//   StatelessRandomGetKeyCounter    -> GetKeyCounterOp
//   StatelessRandomGetAlg           -> GetAlgOp
// Every input and output named here ("seed", "key", "counter", "alg") is
// declared as host memory.
//
// The individual registrations are separated by explicit semicolons, matching
// REGISTER and REGISTER_INT, so the macro does not rely on
// REGISTER_KERNEL_BUILDER's expansion ending in one. As with the other
// registration macros, the expansion itself has no trailing semicolon.
#define REGISTER_GET_KCA(DEVICE) \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounterAlg") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("seed") \
                              .HostMemory("key") \
                              .HostMemory("counter") \
                              .HostMemory("alg"), \
                          GetKeyCounterAlgOp); \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounter") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("seed") \
                              .HostMemory("key") \
                              .HostMemory("counter"), \
                          GetKeyCounterOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("StatelessRandomGetAlg").Device(DEVICE_##DEVICE).HostMemory("alg"), \
      GetAlgOp)
324 | |
// The seed/algorithm helper kernels are registered for CPU unconditionally.
REGISTER_GET_KCA(CPU);

// GPU registrations, compiled only when GOOGLE_CUDA or TENSORFLOW_USE_ROCM is
// set. The type list matches the CPU registrations in this file except that
// bfloat16 is not registered here.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_INT_GPU);
TF_CALL_int64(REGISTER_INT_GPU);
TF_CALL_uint32(REGISTER_FULL_INT_GPU);
TF_CALL_uint64(REGISTER_FULL_INT_GPU);

REGISTER_GET_KCA(GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
340 | |
// Drop every registration helper macro defined in this file so none of them
// leaks past this point. REGISTER_FULL_INT is included: it is defined
// alongside REGISTER and REGISTER_INT and must be cleaned up with them.
#undef REGISTER
#undef REGISTER_FULL_INT
#undef REGISTER_INT
#undef REGISTER_CPU
#undef REGISTER_GPU
#undef REGISTER_INT_CPU
#undef REGISTER_INT_GPU
#undef REGISTER_FULL_INT_CPU
#undef REGISTER_FULL_INT_GPU

#undef REGISTER_GET_KCA
351 | |
352 | } // namespace |
353 | |
354 | } // namespace tensorflow |
355 | |