/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
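
// Kernels for the stateful (variable-backed) random-number ops, e.g.
// StatefulStandardNormalV2, StatefulUniform, and RngSkip. The RNG state lives
// in a resource variable; each kernel advances the stored state under the
// variable's lock and then fills its output from the freshly reserved range.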

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/random_op_cpu.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"

namespace tensorflow {

namespace functor {

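// CPU specialization of `UpdateVariableAndFill_Philox`: it rebuilds a
// PhiloxRandom engine from the variable's state buffer, writes the advanced
// counter back into the variable, releases the variable lock, and only then
// fills the output. Releasing early is safe because the state has already
// been advanced past the range this kernel will consume.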
template <typename Distribution>
struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
  void operator()(OpKernelContext* ctx, const CPUDevice& device,
                  Distribution dist, UpdateVariableAndFill_Philox_Arg* arg,
                  typename Distribution::ResultElementType* output_data)
      TF_UNLOCK_FUNCTION() {
    int64_t output_size = arg->output_size;
    int64_t alg_tag_skip = arg->alg_tag_skip;
    ScopedUnlockUnrefVar* state_var_guard = arg->state_var_guard;
    Tensor* state_tensor = arg->state_tensor;

    auto state_tensor_flat = state_tensor->flat<StateElementType>();
    auto state_data = state_tensor_flat.data();
    // Delegates to PhiloxRandom to advance the counter stored in the state.
    auto philox = GetPhiloxRandomFromMem(state_data + alg_tag_skip);
    UpdateMemWithPhiloxRandom(philox, output_size, state_data + alg_tag_skip);
    // No longer needs the lock.
    state_var_guard->Release();
    functor::FillPhiloxRandom<CPUDevice, Distribution>()(
        ctx, device, /*key=*/nullptr, /*counter=*/nullptr, philox, output_data,
        output_size, dist);
  }
};

}  // end namespace functor

Status CheckState(const Tensor& state) {
  if (state.dtype() != STATE_ELEMENT_DTYPE) {
    return errors::InvalidArgument("dtype of RNG state variable must be ",
                                   DataTypeString(STATE_ELEMENT_DTYPE),
                                   ", not ", DataTypeString(state.dtype()));
  }
  if (state.dims() != 1) {
    return errors::InvalidArgument(
        "RNG state must have one and only one dimension, not ", state.dims());
  }
  return OkStatus();
}

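// The Philox state is a flat int64 vector laid out (as also noted in
// `RngSkipOp` below) as:
//
//   [algorithm tag (V1 ops only)] [counter] [key]
//
// `alg_tag_skip` is the length of the optional tag prefix (0 or 1 elements),
// and PHILOX_MIN_STATE_SIZE covers the counter plus the key.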
Status CheckPhiloxState(const Tensor& state, int64_t alg_tag_skip = 0) {
  static_assert(std::is_same<StateElementType, int64_t>::value,
                "StateElementType must be int64");
  static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value,
                "PhiloxRandom::ResultElementType must be uint32");
  auto min_size = alg_tag_skip + PHILOX_MIN_STATE_SIZE;
  if (state.NumElements() < min_size) {
    return errors::InvalidArgument(
        "For the Philox algorithm, the size of state"
        " must be at least ",
        min_size, "; got ", state.NumElements());
  }
  return OkStatus();
}

// Declared here because `GetAlgId` uses it; the definition is further below.
template <typename T>
Status GetScalar(const Tensor& tensor, int input_idx, T* result);

template <typename AlgEnumType>
StatusOr<AlgEnumType> GetAlgId(OpKernelContext* ctx, int input_idx) {
  AlgEnumType alg_id;
  TF_RETURN_IF_ERROR(GetScalar(ctx->input(input_idx), input_idx, &alg_id));
  return alg_id;
}

template <typename AlgEnumType>
StatusOr<ConcreteRngAlgorithm> ResolveAlg(AlgEnumType alg_id) {
  switch (alg_id) {
    case RNG_ALG_PHILOX:
      return ConcreteRngAlgorithm::RNG_ALG_PHILOX;
    case RNG_ALG_THREEFRY:
      return ConcreteRngAlgorithm::RNG_ALG_THREEFRY;
    case RNG_ALG_AUTO_SELECT:
      // On non-XLA kernels, we pick Philox as the auto-selected algorithm.
      return ConcreteRngAlgorithm::RNG_ALG_PHILOX;
    default:
      return errors::InvalidArgument("Unsupported algorithm id: ", alg_id);
  }
}

template <typename AlgEnumType>
StatusOr<ConcreteRngAlgorithm> GetAlg(OpKernelContext* ctx, int input_idx) {
  TF_ASSIGN_OR_RETURN(auto alg_id, GetAlgId<AlgEnumType>(ctx, input_idx));
  return ResolveAlg(alg_id);
}
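
// For example, a kernel whose algorithm is passed as scalar input 1 can
// resolve it with
//
//   OP_REQUIRES_VALUE(auto alg, ctx, GetAlg<int64_t>(ctx, 1));
//
// which reads the scalar and maps RNG_ALG_AUTO_SELECT onto Philox (this is
// exactly what `StatefulRandomOpV2` below does).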

template <typename Device, typename Distribution>
Status UpdateVariableAndFill(
    OpKernelContext* ctx, Distribution dist, int state_input_idx,
    bool read_alg_from_state, ConcreteRngAlgorithm alg, int64_t output_size,
    typename Distribution::ResultElementType* output_data) {
  Var* var = nullptr;
  TF_RETURN_IF_ERROR(
      LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
  // Use `ScopedUnlockUnrefVar` here instead of `mutex_lock` and `ScopedUnref`
  // because the former supports early releasing, which is needed by
  // `UpdateVariableAndFill_Philox<CPU>` to avoid holding the lock while
  // filling.
  ScopedUnlockUnrefVar state_var_guard(var);
  Tensor* var_tensor = var->tensor();
  TF_RETURN_IF_ERROR(CheckState(*var_tensor));
  auto var_tensor_flat = var_tensor->flat<StateElementType>();
  int64_t alg_tag_skip = 0;
  if (read_alg_from_state) {
    alg_tag_skip = 1;
    if (var_tensor_flat.size() < 1) {
      return errors::InvalidArgument("Size of tensor must be at least 1");
    }
    auto alg_id = var_tensor_flat(0);
    TF_ASSIGN_OR_RETURN(alg, ResolveAlg(alg_id));
  }
  switch (alg) {
    case ConcreteRngAlgorithm::RNG_ALG_PHILOX:
      TF_RETURN_IF_ERROR(CheckPhiloxState(*var_tensor, alg_tag_skip));
      TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, StateElementType>(
          ctx, var_tensor, var->copy_on_read_mode.load()));

      UpdateVariableAndFill_Philox_Arg arg;
      arg.output_size = output_size;
      arg.alg_tag_skip = alg_tag_skip;
      arg.state_var_guard = &state_var_guard;
      arg.state_tensor = var_tensor;
      functor::UpdateVariableAndFill_Philox<Device, Distribution>()(
          ctx, ctx->eigen_device<Device>(), dist, &arg, output_data);
      return OkStatus();
    case ConcreteRngAlgorithm::RNG_ALG_THREEFRY:
      return errors::Unimplemented(
          "Non-XLA devices don't support the ThreeFry algorithm.");
  }
  return errors::Internal(
      "This point shouldn't have been reached because the above switch should "
      "have handled all algorithms.");
}

// Precondition: input(state_input_idx) is an existing resource.
template <typename Device, class Distribution>
void StatefulRandomCompute(OpKernelContext* ctx, Distribution dist,
                           int state_input_idx, int shape_input_idx,
                           bool read_alg_from_state, ConcreteRngAlgorithm alg) {
  using T = typename Distribution::ResultElementType;
  const Tensor& shape_t = ctx->input(shape_input_idx);
  TensorShape shape;
  OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
  Tensor* output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
  auto output_flat = output->flat<T>();
  OP_REQUIRES_OK(ctx, UpdateVariableAndFill<Device>(
                          ctx, dist, state_input_idx, read_alg_from_state, alg,
                          output_flat.size(), output_flat.data()));
}

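// Kernel for the deprecated V1 op ("StatefulStandardNormal"), whose state
// variable stores the algorithm tag as element 0. The algorithm passed to
// `StatefulRandomCompute` below is a placeholder; the real one is read from
// the state because `read_alg_from_state` is true.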
template <typename Device, class Distribution>
class StatefulRandomOp : public OpKernel {
 public:
  explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StatefulRandomCompute<Device>(
        ctx, Distribution(), /*state_input_idx=*/0, /*shape_input_idx=*/1,
        /*read_alg_from_state=*/true,
        ConcreteRngAlgorithm::RNG_ALG_PHILOX /*dummy*/);
  }
};

template <typename T>
Status GetScalar(const Tensor& tensor, int input_idx, T* result) {
  auto dtype = DataTypeToEnum<T>::v();
  if (tensor.dims() != 0) {
    return errors::InvalidArgument("input ", std::to_string(input_idx),
                                   " (0-based) must have shape [], not ",
                                   tensor.shape().DebugString());
  }
  if (tensor.dtype() != dtype) {
    return errors::InvalidArgument("dtype of input ",
                                   std::to_string(input_idx),
                                   " (0-based) must be ", DataTypeString(dtype),
                                   ", not ", DataTypeString(tensor.dtype()));
  }
  *result = tensor.flat<T>()(0);
  return OkStatus();
}

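// Kernel for the V2 ops, which take the algorithm as a separate scalar input
// (input 1) rather than storing it in the state variable.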
template <typename Device, class Distribution>
class StatefulRandomOpV2 : public OpKernel {
 public:
  explicit StatefulRandomOpV2(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES_VALUE(auto alg, ctx, GetAlg<int64_t>(ctx, 1));
    StatefulRandomCompute<Device>(ctx, Distribution(), /*state_input_idx=*/0,
                                  /*shape_input_idx=*/2,
                                  /*read_alg_from_state=*/false, alg);
  }
};

template <typename Device, class IntType>
class StatefulUniformIntOp : public OpKernel {
 public:
  explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES_VALUE(auto alg, ctx, GetAlg<int64_t>(ctx, 1));
    const Tensor& minval = ctx->input(3);
    const Tensor& maxval = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Verify that minval < maxval.
    IntType lo = minval.scalar<IntType>()();
    IntType hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        ctx, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build a uniform integer distribution on [lo, hi).
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);

    StatefulRandomCompute<Device>(ctx, dist, /*state_input_idx=*/0,
                                  /*shape_input_idx=*/2,
                                  /*read_alg_from_state=*/false, alg);
  }
};

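// Like `StatefulUniformIntOp`, but samples uniformly over the full range of
// `IntType`, so no minval/maxval inputs are consumed.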
template <typename Device, class IntType>
class StatefulUniformFullIntOp : public OpKernel {
 public:
  explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES_VALUE(auto alg, ctx, GetAlg<int64_t>(ctx, 1));
    StatefulRandomCompute<Device>(
        ctx,
        random::UniformFullIntDistribution<random::PhiloxRandom, IntType>(),
        /*state_input_idx=*/0, /*shape_input_idx=*/2,
        /*read_alg_from_state=*/false, alg);
  }
};

namespace functor {

template <>
struct RngSkip_Philox<CPUDevice> {
  void operator()(const CPUDevice& device, const StateElementType* in_data,
                  uint64 delta, StateElementType* out_data) {
    // Delegates to PhiloxRandom to advance the counter by `delta`.
    auto counter = GetCounterFromMem(reinterpret_cast<const uint64*>(in_data));
    UpdateCounterMemWithPhiloxRandom(counter, delta, out_data);
  }
};

}  // end namespace functor

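// Kernel for both `RngSkip` and `RngReadAndSkip`. When `read_old_value` is
// true (RngReadAndSkip), the pre-skip state is first copied to an output of
// shape [RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE], zero-padded when the
// algorithm's counter is shorter than RNG_MAX_COUNTER_SIZE.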
template <typename Device, typename AlgEnumType = int64_t,
          typename DeltaType = int64_t, bool read_old_value = false>
class RngSkipOp : public OpKernel {
 public:
  explicit RngSkipOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    auto state_input_idx = 0;
    auto alg_input_idx = 1;
    auto delta_input_idx = 2;
    // GetAlg will treat RNG_ALG_AUTO_SELECT as RNG_ALG_PHILOX.
    OP_REQUIRES_VALUE(auto alg, ctx, GetAlg<AlgEnumType>(ctx, alg_input_idx));
    DeltaType signed_delta;
    OP_REQUIRES_OK(ctx, GetScalar(ctx->input(delta_input_idx), delta_input_idx,
                                  &signed_delta));
    uint64 delta = static_cast<uint64>(signed_delta);
    Var* var = nullptr;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
    ScopedUnlockUnrefVar state_var_guard(var);
    Tensor* var_tensor = var->tensor();
    OP_REQUIRES_OK(ctx, CheckState(*var_tensor));
    using T = StateElementType;
    OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, T>(
                            ctx, var_tensor, var->copy_on_read_mode.load()));
    if (read_old_value) {
      Tensor* output;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(0, {RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE},
                                    &output));
      auto output_flat = output->flat<T>();
      if (RNG_MAX_COUNTER_SIZE > GetCounterSize(alg)) {
        functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                             output_flat);
      }
      functor::DenseUpdate<Device, T, ASSIGN>()(
          ctx->eigen_device<Device>(), output_flat,
          const_cast<const Tensor*>(var_tensor)->flat<T>());
    }
    switch (alg) {
      case ConcreteRngAlgorithm::RNG_ALG_PHILOX: {
        OP_REQUIRES_OK(ctx, CheckPhiloxState(*var_tensor));
        // The state layout is counter followed by key, so the start of the
        // tensor's data is also the start of the counter.
        auto counter_data = var_tensor->flat<T>().data();
        functor::RngSkip_Philox<Device>()(ctx->eigen_device<Device>(),
                                          counter_data, delta, counter_data);
        break;
      }
      case ConcreteRngAlgorithm::RNG_ALG_THREEFRY: {
        OP_REQUIRES(
            ctx, false,
            errors::Unimplemented(
                "Non-XLA devices don't support the ThreeFry algorithm."));
        break;
      }
    }
  }
};

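// Fills the output with integers from `random::New64`, which draws from a
// process-wide generator seeded from a non-deterministic source, so results
// are not reproducible across runs (hence the op's name).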
template <typename T>
class NonDeterministicIntsOp : public OpKernel {
 public:
  explicit NonDeterministicIntsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_t = ctx->input(0);
    TensorShape shape;
    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) return;

    switch (dtype_) {
      case DT_INT32:
      case DT_UINT32:
      case DT_INT64:
      case DT_UINT64: {
        auto output_flat = output->flat<T>();
        auto data = output_flat.data();
        for (int64_t i = 0; i < output_flat.size(); ++i) {
          data[i] = static_cast<T>(random::New64());
        }
        break;
      }
      default:
        OP_REQUIRES(ctx, false,
                    errors::InvalidArgument("Unsupported dtype: ",
                                            DataTypeString(dtype_)));
    }
  }

 private:
  DataType dtype_;
};


// So far the 'Distribution' type parameter is only used when the algorithm is
// Philox, so 'NormalDistribution<PhiloxRandom, ...>' is fine for now.
#define REGISTER_FloatOps(DEVICE, TYPE)                                      \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("StatefulStandardNormalV2")                                       \
          .Device(DEVICE_##DEVICE)                                           \
          .HostMemory("resource")                                            \
          .HostMemory("algorithm")                                           \
          .HostMemory("shape")                                               \
          .TypeConstraint<TYPE>("dtype"),                                    \
      StatefulRandomOpV2<DEVICE##Device,                                     \
                         random::NormalDistribution<PhiloxRandom, TYPE> >);  \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("StatefulUniform")                                                \
          .Device(DEVICE_##DEVICE)                                           \
          .HostMemory("resource")                                            \
          .HostMemory("algorithm")                                           \
          .HostMemory("shape")                                               \
          .TypeConstraint<TYPE>("dtype"),                                    \
      StatefulRandomOpV2<DEVICE##Device,                                     \
                         random::UniformDistribution<PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("StatefulTruncatedNormal")                                        \
          .Device(DEVICE_##DEVICE)                                           \
          .HostMemory("resource")                                            \
          .HostMemory("algorithm")                                           \
          .HostMemory("shape")                                               \
          .TypeConstraint<TYPE>("dtype"),                                    \
      StatefulRandomOpV2<                                                    \
          DEVICE##Device,                                                    \
          random::TruncatedNormalDistribution<                               \
              random::SingleSampleAdapter<PhiloxRandom>, TYPE> >);
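
// For example, `REGISTER_FloatOps(CPU, float)` registers the CPU float
// kernels for StatefulStandardNormalV2, StatefulUniform and
// StatefulTruncatedNormal.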

// CPU also has the deprecated 'StatefulStandardNormal' op for backward
// compatibility.
#define REGISTER_FloatOps_CPU(TYPE)                                      \
  REGISTER_FloatOps(CPU, TYPE) REGISTER_KERNEL_BUILDER(                  \
      Name("StatefulStandardNormal")                                     \
          .Device(DEVICE_CPU)                                            \
          .HostMemory("resource")                                        \
          .HostMemory("shape")                                           \
          .TypeConstraint<TYPE>("dtype"),                                \
      StatefulRandomOp<CPUDevice,                                        \
                       random::NormalDistribution<PhiloxRandom, TYPE> >);

#define REGISTER_FloatOps_GPU(TYPE) REGISTER_FloatOps(GPU, TYPE)

TF_CALL_half(REGISTER_FloatOps_CPU);
TF_CALL_bfloat16(REGISTER_FloatOps_CPU);
TF_CALL_float(REGISTER_FloatOps_CPU);
TF_CALL_double(REGISTER_FloatOps_CPU);

#define REGISTER_StatefulUniformInt(DEVICE, TYPE)             \
  REGISTER_KERNEL_BUILDER(Name("StatefulUniformInt")          \
                              .Device(DEVICE_##DEVICE)        \
                              .HostMemory("resource")         \
                              .HostMemory("algorithm")        \
                              .HostMemory("shape")            \
                              .HostMemory("minval")           \
                              .HostMemory("maxval")           \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatefulUniformIntOp<DEVICE##Device, TYPE>);

#define REGISTER_StatefulUniformInt_CPU(TYPE) \
  REGISTER_StatefulUniformInt(CPU, TYPE)
#define REGISTER_StatefulUniformInt_GPU(TYPE) \
  REGISTER_StatefulUniformInt(GPU, TYPE)

TF_CALL_int32(REGISTER_StatefulUniformInt_CPU);
TF_CALL_int64(REGISTER_StatefulUniformInt_CPU);

#define REGISTER_StatefulUniformFullInt(DEVICE, TYPE)         \
  REGISTER_KERNEL_BUILDER(Name("StatefulUniformFullInt")      \
                              .Device(DEVICE_##DEVICE)        \
                              .HostMemory("resource")         \
                              .HostMemory("algorithm")        \
                              .HostMemory("shape")            \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatefulUniformFullIntOp<DEVICE##Device, TYPE>);

#define REGISTER_StatefulUniformFullInt_CPU(TYPE) \
  REGISTER_StatefulUniformFullInt(CPU, TYPE)
#define REGISTER_StatefulUniformFullInt_GPU(TYPE) \
  REGISTER_StatefulUniformFullInt(GPU, TYPE)

TF_CALL_int32(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_int64(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_uint32(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_uint64(REGISTER_StatefulUniformFullInt_CPU);

// TODO(wangpeng): Remove `HostMemory("delta")` for RngReadAndSkip
#define REGISTER_RngSkip(DEVICE)                       \
  REGISTER_KERNEL_BUILDER(Name("RngSkip")              \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("resource")  \
                              .HostMemory("algorithm") \
                              .HostMemory("delta"),    \
                          RngSkipOp<DEVICE##Device>);  \
  REGISTER_KERNEL_BUILDER(Name("RngReadAndSkip")       \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("resource")  \
                              .HostMemory("alg")       \
                              .HostMemory("delta"),    \
                          RngSkipOp<DEVICE##Device, int32, uint64, true>);

REGISTER_RngSkip(CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_half(REGISTER_FloatOps_GPU);
TF_CALL_float(REGISTER_FloatOps_GPU);
TF_CALL_double(REGISTER_FloatOps_GPU);
TF_CALL_int32(REGISTER_StatefulUniformInt_GPU);
TF_CALL_int64(REGISTER_StatefulUniformInt_GPU);
TF_CALL_int32(REGISTER_StatefulUniformFullInt_GPU);
TF_CALL_int64(REGISTER_StatefulUniformFullInt_GPU);
TF_CALL_uint32(REGISTER_StatefulUniformFullInt_GPU);
TF_CALL_uint64(REGISTER_StatefulUniformFullInt_GPU);
REGISTER_RngSkip(GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_StatefulUniformFullInt_GPU
#undef REGISTER_StatefulUniformFullInt_CPU
#undef REGISTER_StatefulUniformFullInt
#undef REGISTER_StatefulUniformInt_GPU
#undef REGISTER_StatefulUniformInt_CPU
#undef REGISTER_StatefulUniformInt
#undef REGISTER_FloatOps_GPU
#undef REGISTER_FloatOps_CPU
#undef REGISTER_FloatOps

#define REGISTER_NonDeterministicInts(TYPE)                   \
  REGISTER_KERNEL_BUILDER(Name("NonDeterministicInts")        \
                              .Device(DEVICE_CPU)             \
                              .HostMemory("shape")            \
                              .TypeConstraint<TYPE>("dtype"), \
                          NonDeterministicIntsOp<TYPE>);

TF_CALL_int32(REGISTER_NonDeterministicInts);
TF_CALL_uint32(REGISTER_NonDeterministicInts);
TF_CALL_int64(REGISTER_NonDeterministicInts);
TF_CALL_uint64(REGISTER_NonDeterministicInts);

#undef REGISTER_NonDeterministicInts

// TODO(wangpeng): Add RNG ops for other distributions.

}  // end namespace tensorflow