1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_UTIL_GUARDED_PHILOX_RANDOM_H_ |
17 | #define TENSORFLOW_CORE_UTIL_GUARDED_PHILOX_RANDOM_H_ |
18 | |
19 | #include "tensorflow/core/framework/op_kernel.h" |
20 | #include "tensorflow/core/lib/random/philox_random.h" |
21 | #include "tensorflow/core/platform/macros.h" |
22 | #include "tensorflow/core/platform/mutex.h" |
23 | #include "tensorflow/core/platform/types.h" |
24 | |
25 | namespace tensorflow { |
26 | |
// A thread-safe wrapper around a Philox generator.  Example usage:
//
//   GuardedPhiloxRandom generator;
//   generator.Init(context);
//
//   // In thread-safe code
//   const int samples = ...;
//   auto local_generator = generator.ReserveSamples128(samples);
//   for (int i = 0; i < samples; i++) {
//     Array<uint32, 4> sample = local_generator();
//     // Use sample
//   }
//
40 | class GuardedPhiloxRandom { |
41 | public: |
42 | // Must call Init to finish initialization |
43 | GuardedPhiloxRandom() : initialized_(false) {} |
44 | |
45 | // Initialize the generator from attributes "seed" and "seed2". |
46 | // If both seeds are unspecified, use random seeds. |
47 | // Must be called exactly once. |
48 | Status Init(OpKernelConstruction* context); |
49 | |
50 | // Initialize with given seeds. |
51 | void Init(int64_t seed, int64_t seed2); |
52 | void Init(random::PhiloxRandom::ResultType counter, |
53 | random::PhiloxRandom::Key key); |
54 | |
55 | // Reserve a certain number of 128-bit samples. |
56 | // This function is thread safe. The returned generator is valid for the |
57 | // given number of samples, and can be used without a lock. |
58 | random::PhiloxRandom ReserveSamples128(int64_t samples); |
59 | |
60 | // Reserve a certain number of 32-bit samples. |
61 | random::PhiloxRandom ReserveSamples32(int64_t samples) { |
62 | return ReserveSamples128((samples + 3) / 4); |
63 | } |
64 | |
65 | // Reserve enough random samples in the generator for the given output count. |
66 | random::PhiloxRandom ReserveRandomOutputs(int64_t output_count, |
67 | int multiplier) { |
68 | int64_t conservative_sample_count = output_count * multiplier; |
69 | return ReserveSamples128(conservative_sample_count); |
70 | } |
71 | |
72 | private: |
73 | mutex mu_; |
74 | random::PhiloxRandom generator_ TF_GUARDED_BY(mu_); |
75 | bool initialized_; |
76 | |
77 | TF_DISALLOW_COPY_AND_ASSIGN(GuardedPhiloxRandom); |
78 | }; |
79 | |
80 | } // namespace tensorflow |
81 | |
82 | #endif // TENSORFLOW_CORE_UTIL_GUARDED_PHILOX_RANDOM_H_ |
83 | |