1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
17#define TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
18
19#include "tensorflow/core/lib/random/philox_random.h"
20#include "tensorflow/core/platform/types.h"
21
22namespace tensorflow {
23
24using random::PhiloxRandom;
25
// The following two functions use the convention "lower 32 bits for the first
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
// unlike a direct memory copy such as `memcpy(output, &input, 8)`.
29PHILOX_DEVICE_INLINE void Uint64ToUint32s(uint64 input, uint32* output1,
30 uint32* output2) {
31 *output1 = static_cast<uint32>(input);
32 *output2 = static_cast<uint32>(input >> 32);
33}
34
35PHILOX_DEVICE_INLINE uint64 Uint32sToUint64(uint32 input1, uint32 input2) {
36 auto u64_1 = static_cast<uint64>(input1);
37 auto u64_2 = static_cast<uint64>(input2);
38 return u64_1 | (u64_2 << 32);
39}
40
41PHILOX_DEVICE_INLINE PhiloxRandom::ResultType GetCounterFromMem(
42 uint64 const* ptr) {
43 PhiloxRandom::ResultType counter;
44 Uint64ToUint32s(ptr[0], &counter[0], &counter[1]);
45 Uint64ToUint32s(ptr[1], &counter[2], &counter[3]);
46 return counter;
47}
48
49PHILOX_DEVICE_INLINE void WriteCounterToMem(
50 PhiloxRandom::ResultType const& counter, uint64* ptr) {
51 ptr[0] = Uint32sToUint64(counter[0], counter[1]);
52 ptr[1] = Uint32sToUint64(counter[2], counter[3]);
53}
54
55PHILOX_DEVICE_INLINE PhiloxRandom::Key GetKeyFromMem(uint64 const* ptr) {
56 PhiloxRandom::Key key;
57 Uint64ToUint32s(ptr[0], &key[0], &key[1]);
58 return key;
59}
60
61PHILOX_DEVICE_INLINE void WriteKeyToMem(PhiloxRandom::Key const& key,
62 uint64* ptr) {
63 *ptr = Uint32sToUint64(key[0], key[1]);
64}
65
66PHILOX_DEVICE_INLINE PhiloxRandom GetPhiloxRandomFromCounterKeyMem(
67 uint64 const* counter_ptr, uint64 const* key_ptr) {
68 return PhiloxRandom(GetCounterFromMem(counter_ptr), GetKeyFromMem(key_ptr));
69}
70
71} // end namespace tensorflow
72
73#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
74