1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/core/Generator.h> |
4 | #include <ATen/core/MT19937RNGEngine.h> |
5 | #include <c10/core/GeneratorImpl.h> |
6 | #include <c10/util/Optional.h> |
7 | |
8 | namespace at { |
9 | |
10 | struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl { |
11 | // Constructors |
12 | CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val); |
13 | ~CPUGeneratorImpl() override = default; |
14 | |
15 | // CPUGeneratorImpl methods |
16 | std::shared_ptr<CPUGeneratorImpl> clone() const; |
17 | void set_current_seed(uint64_t seed) override; |
18 | uint64_t current_seed() const override; |
19 | uint64_t seed() override; |
20 | void set_state(const c10::TensorImpl& new_state) override; |
21 | c10::intrusive_ptr<c10::TensorImpl> get_state() const override; |
22 | static DeviceType device_type(); |
23 | uint32_t random(); |
24 | uint64_t random64(); |
25 | c10::optional<float> next_float_normal_sample(); |
26 | c10::optional<double> next_double_normal_sample(); |
27 | void set_next_float_normal_sample(c10::optional<float> randn); |
28 | void set_next_double_normal_sample(c10::optional<double> randn); |
29 | at::mt19937 engine(); |
30 | void set_engine(at::mt19937 engine); |
31 | |
32 | private: |
33 | CPUGeneratorImpl* clone_impl() const override; |
34 | at::mt19937 engine_; |
35 | c10::optional<float> next_float_normal_sample_; |
36 | c10::optional<double> next_double_normal_sample_; |
37 | }; |
38 | |
39 | namespace detail { |
40 | |
41 | TORCH_API const Generator& getDefaultCPUGenerator(); |
42 | TORCH_API Generator |
43 | createCPUGenerator(uint64_t seed_val = default_rng_seed_val); |
44 | |
45 | } // namespace detail |
46 | |
47 | } // namespace at |
48 |