1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_ |
18 | |
19 | #include "tensorflow/core/lib/random/philox_random.h" |
20 | #include "tensorflow/core/platform/types.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | using random::PhiloxRandom; |
25 | |
// The following two functions share one contract: the first uint32 holds the
// lower 32 bits of the uint64 and the second uint32 holds the higher 32 bits.
// Because the split is defined with shifts and casts rather than memory layout,
// it is endian-neutral, unlike a direct memory copy `memcpy(output, &input, 8)`.
29 | PHILOX_DEVICE_INLINE void Uint64ToUint32s(uint64 input, uint32* output1, |
30 | uint32* output2) { |
31 | *output1 = static_cast<uint32>(input); |
32 | *output2 = static_cast<uint32>(input >> 32); |
33 | } |
34 | |
35 | PHILOX_DEVICE_INLINE uint64 Uint32sToUint64(uint32 input1, uint32 input2) { |
36 | auto u64_1 = static_cast<uint64>(input1); |
37 | auto u64_2 = static_cast<uint64>(input2); |
38 | return u64_1 | (u64_2 << 32); |
39 | } |
40 | |
41 | PHILOX_DEVICE_INLINE PhiloxRandom::ResultType GetCounterFromMem( |
42 | uint64 const* ptr) { |
43 | PhiloxRandom::ResultType counter; |
44 | Uint64ToUint32s(ptr[0], &counter[0], &counter[1]); |
45 | Uint64ToUint32s(ptr[1], &counter[2], &counter[3]); |
46 | return counter; |
47 | } |
48 | |
49 | PHILOX_DEVICE_INLINE void WriteCounterToMem( |
50 | PhiloxRandom::ResultType const& counter, uint64* ptr) { |
51 | ptr[0] = Uint32sToUint64(counter[0], counter[1]); |
52 | ptr[1] = Uint32sToUint64(counter[2], counter[3]); |
53 | } |
54 | |
55 | PHILOX_DEVICE_INLINE PhiloxRandom::Key GetKeyFromMem(uint64 const* ptr) { |
56 | PhiloxRandom::Key key; |
57 | Uint64ToUint32s(ptr[0], &key[0], &key[1]); |
58 | return key; |
59 | } |
60 | |
61 | PHILOX_DEVICE_INLINE void WriteKeyToMem(PhiloxRandom::Key const& key, |
62 | uint64* ptr) { |
63 | *ptr = Uint32sToUint64(key[0], key[1]); |
64 | } |
65 | |
66 | PHILOX_DEVICE_INLINE PhiloxRandom GetPhiloxRandomFromCounterKeyMem( |
67 | uint64 const* counter_ptr, uint64 const* key_ptr) { |
68 | return PhiloxRandom(GetCounterFromMem(counter_ptr), GetKeyFromMem(key_ptr)); |
69 | } |
70 | |
71 | } // end namespace tensorflow |
72 | |
73 | #endif // TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_ |
74 | |