/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_

#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/stateful_random_ops.h"

namespace tensorflow {

PHILOX_DEVICE_INLINE PhiloxRandom
GetPhiloxRandomFromMem(StateElementType const* ptr) {
  auto ptr_ = reinterpret_cast<uint64 const*>(ptr);
  return GetPhiloxRandomFromCounterKeyMem(ptr_, ptr_ + 2);
}

PHILOX_DEVICE_INLINE void WritePhiloxRandomToMem(PhiloxRandom const& philox,
                                                 StateElementType* ptr) {
  auto ptr_ = reinterpret_cast<uint64*>(ptr);
  WriteCounterToMem(philox.counter(), ptr_);
  WriteKeyToMem(philox.key(), ptr_ + 2);
}

PHILOX_DEVICE_INLINE PhiloxRandom SkipPhiloxRandom(PhiloxRandom const& philox,
                                                   uint64 output_size) {
  auto new_philox = philox;
  // The multiplier 256 must match the one used in FillPhiloxRandomTask; do
  // not change it in one place without changing the other.
  auto delta = output_size * 256;
  new_philox.Skip(delta);  // advance the counter
  return new_philox;
}

PHILOX_DEVICE_INLINE void UpdateMemWithPhiloxRandom(PhiloxRandom const& philox,
                                                    uint64 output_size,
                                                    StateElementType* ptr) {
  auto new_philox = SkipPhiloxRandom(philox, output_size);
  WritePhiloxRandomToMem(new_philox, ptr);
}

PHILOX_DEVICE_INLINE void UpdateCounterMemWithPhiloxRandom(
    PhiloxRandom::ResultType const& counter, uint64 output_size,
    StateElementType* ptr) {
  auto philox = PhiloxRandom(counter, PhiloxRandom::Key() /*dummy*/);
  auto new_philox = SkipPhiloxRandom(philox, output_size);
  WriteCounterToMem(new_philox.counter(), reinterpret_cast<uint64*>(ptr));
}
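
// Illustrative (non-normative) usage of the helpers above, assuming `state`
// points to an RNG-state buffer of at least 3 StateElementType words laid out
// as [counter_lo, counter_hi, key] (the layout implied by the `ptr_ + 2` key
// offset above):
//
//   PhiloxRandom philox = GetPhiloxRandomFromMem(state);
//   // ... generate `output_size` outputs from `philox` ...
//   UpdateMemWithPhiloxRandom(philox, output_size, state);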

namespace functor {

// A per-device helper functor that does the actual work for
// `UpdateVariableAndFill`.
// A functor (class template) is used because C++ does not allow partial
// specialization of function templates.
template <typename Device, typename Distribution>
struct UpdateVariableAndFill_Philox;
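
// For illustration only (not part of this header's API): class templates
// admit partial specialization while function templates do not, which is why
// the per-device dispatch above is expressed as a struct. A minimal sketch
// with a hypothetical `Helper`:
//
//   template <typename Device, typename Dist> struct Helper;             // primary
//   template <typename Dist> struct Helper<CPUDevice, Dist> { /*...*/ };  // OK
//
//   template <typename Device, typename Dist> void Helper(Dist);
//   template <typename Dist> void Helper<CPUDevice, Dist>(Dist);  // ill-formed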

template <typename Device>
struct RngSkip_Philox;

}  // end namespace functor

using CPUDevice = Eigen::ThreadPoolDevice;

class ScopedUnlockUnrefVar;

struct UpdateVariableAndFill_Philox_Arg {
  int64_t output_size;
  int64_t alg_tag_skip;
  ScopedUnlockUnrefVar* state_var_guard;
  Tensor* state_tensor;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

using GPUDevice = Eigen::GpuDevice;

namespace functor {

// Declares the partially GPU-specialized functor structs.
// The operator() below must be kept at <= 6 arguments because of a gcc/clang
// ABI-incompatibility bug.
template <typename Distribution>
struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
  void operator()(OpKernelContext* ctx, const GPUDevice& device,
                  Distribution dist, UpdateVariableAndFill_Philox_Arg* arg,
                  typename Distribution::ResultElementType* output_data);
};

template <>
struct RngSkip_Philox<GPUDevice> {
  void operator()(const GPUDevice& device, const StateElementType* in_data,
                  uint64 delta, StateElementType* out_data);
};

}  // end namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_