#include <gtest/gtest.h>
#include <ATen/test/rng_test.h>
#include <ATen/Generator.h>
#include <c10/core/GeneratorImpl.h>
#include <ATen/Tensor.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/cpu/DistributionTemplates.h>
#include <torch/library.h>
#include <c10/util/Optional.h>
#include <torch/all.h>
#include <limits>
#include <stdexcept>
12
13using namespace at;
14
15#ifndef ATEN_CPU_STATIC_DISPATCH
16namespace {
17
18constexpr auto kCustomRNG = DispatchKey::CustomRNGKeyId;
19
20struct TestCPUGenerator : public c10::GeneratorImpl {
21 TestCPUGenerator(uint64_t value) : GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(kCustomRNG)}, value_(value) { }
22 // NOLINTNEXTLINE(modernize-use-override)
23 ~TestCPUGenerator() = default;
24 uint32_t random() { return value_; }
25 uint64_t random64() { return value_; }
26 c10::optional<float> next_float_normal_sample() { return next_float_normal_sample_; }
27 c10::optional<double> next_double_normal_sample() { return next_double_normal_sample_; }
28 void set_next_float_normal_sample(c10::optional<float> randn) { next_float_normal_sample_ = randn; }
29 void set_next_double_normal_sample(c10::optional<double> randn) { next_double_normal_sample_ = randn; }
30 void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
31 uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
32 uint64_t seed() override { throw std::runtime_error("not implemented"); }
33 void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); }
34 c10::intrusive_ptr<c10::TensorImpl> get_state() const override { throw std::runtime_error("not implemented"); }
35 TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); }
36
37 static DeviceType device_type() { return DeviceType::CPU; }
38
39 uint64_t value_;
40 c10::optional<float> next_float_normal_sample_;
41 c10::optional<double> next_double_normal_sample_;
42};
43
44// ==================================================== Random ========================================================
45
46Tensor& random_(Tensor& self, c10::optional<Generator> generator) {
47 return at::native::templates::random_impl<native::templates::cpu::RandomKernel, TestCPUGenerator>(self, generator);
48}
49
50Tensor& random_from_to(Tensor& self, int64_t from, optional<int64_t> to, c10::optional<Generator> generator) {
51 return at::native::templates::random_from_to_impl<native::templates::cpu::RandomFromToKernel, TestCPUGenerator>(self, from, to, generator);
52}
53
54Tensor& random_to(Tensor& self, int64_t to, c10::optional<Generator> generator) {
55 return random_from_to(self, 0, to, generator);
56}
57
58// ==================================================== Normal ========================================================
59
60Tensor& normal_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
61 return at::native::templates::normal_impl_<native::templates::cpu::NormalKernel, TestCPUGenerator>(self, mean, std, gen);
62}
63
64Tensor& normal_Tensor_float_out(const Tensor& mean, double std, c10::optional<Generator> gen, Tensor& output) {
65 return at::native::templates::normal_out_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(output, mean, std, gen);
66}
67
68Tensor& normal_float_Tensor_out(double mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
69 return at::native::templates::normal_out_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(output, mean, std, gen);
70}
71
72Tensor& normal_Tensor_Tensor_out(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
73 return at::native::templates::normal_out_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(output, mean, std, gen);
74}
75
76Tensor normal_Tensor_float(const Tensor& mean, double std, c10::optional<Generator> gen) {
77 return at::native::templates::normal_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(mean, std, gen);
78}
79
80Tensor normal_float_Tensor(double mean, const Tensor& std, c10::optional<Generator> gen) {
81 return at::native::templates::normal_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(mean, std, gen);
82}
83
84Tensor normal_Tensor_Tensor(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
85 return at::native::templates::normal_impl<native::templates::cpu::NormalKernel, TestCPUGenerator>(mean, std, gen);
86}
87
88// ==================================================== Uniform =======================================================
89
90Tensor& uniform_(Tensor& self, double from, double to, c10::optional<Generator> generator) {
91 return at::native::templates::uniform_impl_<native::templates::cpu::UniformKernel, TestCPUGenerator>(self, from, to, generator);
92}
93
94// ==================================================== Cauchy ========================================================
95
96Tensor& cauchy_(Tensor& self, double median, double sigma, c10::optional<Generator> generator) {
97 return at::native::templates::cauchy_impl_<native::templates::cpu::CauchyKernel, TestCPUGenerator>(self, median, sigma, generator);
98}
99
100// ================================================== LogNormal =======================================================
101
102Tensor& log_normal_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
103 return at::native::templates::log_normal_impl_<native::templates::cpu::LogNormalKernel, TestCPUGenerator>(self, mean, std, gen);
104}
105
106// ================================================== Geometric =======================================================
107
108Tensor& geometric_(Tensor& self, double p, c10::optional<Generator> gen) {
109 return at::native::templates::geometric_impl_<native::templates::cpu::GeometricKernel, TestCPUGenerator>(self, p, gen);
110}
111
112// ================================================== Exponential =====================================================
113
114Tensor& exponential_(Tensor& self, double lambda, c10::optional<Generator> gen) {
115 return at::native::templates::exponential_impl_<native::templates::cpu::ExponentialKernel, TestCPUGenerator>(self, lambda, gen);
116}
117
118// ================================================== Bernoulli =======================================================
119
120Tensor& bernoulli_Tensor(Tensor& self, const Tensor& p_, c10::optional<Generator> gen) {
121 return at::native::templates::bernoulli_impl_<native::templates::cpu::BernoulliKernel, TestCPUGenerator>(self, p_, gen);
122}
123
124Tensor& bernoulli_float(Tensor& self, double p, c10::optional<Generator> gen) {
125 return at::native::templates::bernoulli_impl_<native::templates::cpu::BernoulliKernel, TestCPUGenerator>(self, p, gen);
126}
127
128Tensor& bernoulli_out(const Tensor& self, c10::optional<Generator> gen, Tensor& result) {
129 return at::native::templates::bernoulli_out_impl<native::templates::cpu::BernoulliKernel, TestCPUGenerator>(result, self, gen);
130}
131
132TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
133 // Random
134 m.impl("random_.from", random_from_to);
135 m.impl("random_.to", random_to);
136 m.impl("random_", random_);
137 // Normal
138 m.impl("normal_", normal_);
139 m.impl("normal.Tensor_float_out", normal_Tensor_float_out);
140 m.impl("normal.float_Tensor_out", normal_float_Tensor_out);
141 m.impl("normal.Tensor_Tensor_out", normal_Tensor_Tensor_out);
142 m.impl("normal.Tensor_float", normal_Tensor_float);
143 m.impl("normal.float_Tensor", normal_float_Tensor);
144 m.impl("normal.Tensor_Tensor", normal_Tensor_Tensor);
145 m.impl("uniform_", uniform_);
146 // Cauchy
147 m.impl("cauchy_", cauchy_);
148 // LogNormal
149 m.impl("log_normal_", log_normal_);
150 // Geometric
151 m.impl("geometric_", geometric_);
152 // Exponential
153 m.impl("exponential_", exponential_);
154 // Bernoulli
155 m.impl("bernoulli.out", bernoulli_out);
156 m.impl("bernoulli_.Tensor", bernoulli_Tensor);
157 m.impl("bernoulli_.float", bernoulli_float);
158}
159
160class RNGTest : public ::testing::Test {
161};
162
// Arbitrary fixed value handed to every TestCPUGenerator; all draws return it.
// (Already internal linkage: constexpr inside an anonymous namespace.)
constexpr unsigned long long MAGIC_NUMBER{424242424242424242ULL};
164
165// ==================================================== Random ========================================================
166
167TEST_F(RNGTest, RandomFromTo) {
168 const at::Device device("cpu");
169 test_random_from_to<TestCPUGenerator, torch::kBool, bool>(device);
170 test_random_from_to<TestCPUGenerator, torch::kUInt8, uint8_t>(device);
171 test_random_from_to<TestCPUGenerator, torch::kInt8, int8_t>(device);
172 test_random_from_to<TestCPUGenerator, torch::kInt16, int16_t>(device);
173 test_random_from_to<TestCPUGenerator, torch::kInt32, int32_t>(device);
174 test_random_from_to<TestCPUGenerator, torch::kInt64, int64_t>(device);
175 test_random_from_to<TestCPUGenerator, torch::kFloat32, float>(device);
176 test_random_from_to<TestCPUGenerator, torch::kFloat64, double>(device);
177}
178
179TEST_F(RNGTest, Random) {
180 const at::Device device("cpu");
181 test_random<TestCPUGenerator, torch::kBool, bool>(device);
182 test_random<TestCPUGenerator, torch::kUInt8, uint8_t>(device);
183 test_random<TestCPUGenerator, torch::kInt8, int8_t>(device);
184 test_random<TestCPUGenerator, torch::kInt16, int16_t>(device);
185 test_random<TestCPUGenerator, torch::kInt32, int32_t>(device);
186 test_random<TestCPUGenerator, torch::kInt64, int64_t>(device);
187 test_random<TestCPUGenerator, torch::kFloat32, float>(device);
188 test_random<TestCPUGenerator, torch::kFloat64, double>(device);
189}
190
191// This test proves that Tensor.random_() distribution is able to generate unsigned 64 bit max value(64 ones)
192// https://github.com/pytorch/pytorch/issues/33299
193TEST_F(RNGTest, Random64bits) {
194 auto gen = at::make_generator<TestCPUGenerator>(std::numeric_limits<uint64_t>::max());
195 auto actual = torch::empty({1}, torch::kInt64);
196 actual.random_(std::numeric_limits<int64_t>::min(), c10::nullopt, gen);
197 ASSERT_EQ(static_cast<uint64_t>(actual[0].item<int64_t>()), std::numeric_limits<uint64_t>::max());
198}
199
200// ==================================================== Normal ========================================================
201
202TEST_F(RNGTest, Normal) {
203 const auto mean = 123.45;
204 const auto std = 67.89;
205 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
206
207 auto actual = torch::empty({10});
208 actual.normal_(mean, std, gen);
209
210 auto expected = torch::empty_like(actual);
211 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
212
213 ASSERT_TRUE(torch::allclose(actual, expected));
214}
215
216TEST_F(RNGTest, Normal_float_Tensor_out) {
217 const auto mean = 123.45;
218 const auto std = 67.89;
219 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
220
221 auto actual = torch::empty({10});
222 at::normal_out(actual, mean, torch::full({10}, std), gen);
223
224 auto expected = torch::empty_like(actual);
225 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
226
227 ASSERT_TRUE(torch::allclose(actual, expected));
228}
229
230TEST_F(RNGTest, Normal_Tensor_float_out) {
231 const auto mean = 123.45;
232 const auto std = 67.89;
233 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
234
235 auto actual = torch::empty({10});
236 at::normal_out(actual, torch::full({10}, mean), std, gen);
237
238 auto expected = torch::empty_like(actual);
239 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
240
241 ASSERT_TRUE(torch::allclose(actual, expected));
242}
243
244TEST_F(RNGTest, Normal_Tensor_Tensor_out) {
245 const auto mean = 123.45;
246 const auto std = 67.89;
247 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
248
249 auto actual = torch::empty({10});
250 at::normal_out(actual, torch::full({10}, mean), torch::full({10}, std), gen);
251
252 auto expected = torch::empty_like(actual);
253 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
254
255 ASSERT_TRUE(torch::allclose(actual, expected));
256}
257
258TEST_F(RNGTest, Normal_float_Tensor) {
259 const auto mean = 123.45;
260 const auto std = 67.89;
261 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
262
263 auto actual = at::normal(mean, torch::full({10}, std), gen);
264
265 auto expected = torch::empty_like(actual);
266 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
267
268 ASSERT_TRUE(torch::allclose(actual, expected));
269}
270
271TEST_F(RNGTest, Normal_Tensor_float) {
272 const auto mean = 123.45;
273 const auto std = 67.89;
274 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
275
276 auto actual = at::normal(torch::full({10}, mean), std, gen);
277
278 auto expected = torch::empty_like(actual);
279 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
280
281 ASSERT_TRUE(torch::allclose(actual, expected));
282}
283
284TEST_F(RNGTest, Normal_Tensor_Tensor) {
285 const auto mean = 123.45;
286 const auto std = 67.89;
287 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
288
289 auto actual = at::normal(torch::full({10}, mean), torch::full({10}, std), gen);
290
291 auto expected = torch::empty_like(actual);
292 native::templates::cpu::normal_kernel(expected, mean, std, check_generator<TestCPUGenerator>(gen));
293
294 ASSERT_TRUE(torch::allclose(actual, expected));
295}
296
297// ==================================================== Uniform =======================================================
298
299TEST_F(RNGTest, Uniform) {
300 const auto from = -24.24;
301 const auto to = 42.42;
302 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
303
304 auto actual = torch::empty({3, 3});
305 actual.uniform_(from, to, gen);
306
307 auto expected = torch::empty_like(actual);
308 auto iter = TensorIterator::nullary_op(expected);
309 native::templates::cpu::uniform_kernel(iter, from, to, check_generator<TestCPUGenerator>(gen));
310
311 ASSERT_TRUE(torch::allclose(actual, expected));
312}
313
314// ==================================================== Cauchy ========================================================
315
316TEST_F(RNGTest, Cauchy) {
317 const auto median = 123.45;
318 const auto sigma = 67.89;
319 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
320
321 auto actual = torch::empty({3, 3});
322 actual.cauchy_(median, sigma, gen);
323
324 auto expected = torch::empty_like(actual);
325 auto iter = TensorIterator::nullary_op(expected);
326 native::templates::cpu::cauchy_kernel(iter, median, sigma, check_generator<TestCPUGenerator>(gen));
327
328 ASSERT_TRUE(torch::allclose(actual, expected));
329}
330
331// ================================================== LogNormal =======================================================
332
333TEST_F(RNGTest, LogNormal) {
334 const auto mean = 12.345;
335 const auto std = 6.789;
336 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
337
338 auto actual = torch::empty({10});
339 actual.log_normal_(mean, std, gen);
340
341 auto expected = torch::empty_like(actual);
342 auto iter = TensorIterator::nullary_op(expected);
343 native::templates::cpu::log_normal_kernel(iter, mean, std, check_generator<TestCPUGenerator>(gen));
344
345 ASSERT_TRUE(torch::allclose(actual, expected));
346}
347
348// ================================================== Geometric =======================================================
349
350TEST_F(RNGTest, Geometric) {
351 const auto p = 0.42;
352 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
353
354 auto actual = torch::empty({3, 3});
355 actual.geometric_(p, gen);
356
357 auto expected = torch::empty_like(actual);
358 auto iter = TensorIterator::nullary_op(expected);
359 native::templates::cpu::geometric_kernel(iter, p, check_generator<TestCPUGenerator>(gen));
360
361 ASSERT_TRUE(torch::allclose(actual, expected));
362}
363
364// ================================================== Exponential =====================================================
365
366TEST_F(RNGTest, Exponential) {
367 const auto lambda = 42;
368 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
369
370 auto actual = torch::empty({3, 3});
371 actual.exponential_(lambda, gen);
372
373 auto expected = torch::empty_like(actual);
374 auto iter = TensorIterator::nullary_op(expected);
375 native::templates::cpu::exponential_kernel(iter, lambda, check_generator<TestCPUGenerator>(gen));
376
377 ASSERT_TRUE(torch::allclose(actual, expected));
378}
379
380// ==================================================== Bernoulli =====================================================
381
382TEST_F(RNGTest, Bernoulli_Tensor) {
383 const auto p = 0.42;
384 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
385
386 auto actual = torch::empty({3, 3});
387 actual.bernoulli_(torch::full({3,3}, p), gen);
388
389 auto expected = torch::empty_like(actual);
390 native::templates::cpu::bernoulli_kernel(expected, torch::full({3,3}, p), check_generator<TestCPUGenerator>(gen));
391
392 ASSERT_TRUE(torch::allclose(actual, expected));
393}
394
395TEST_F(RNGTest, Bernoulli_scalar) {
396 const auto p = 0.42;
397 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
398
399 auto actual = torch::empty({3, 3});
400 actual.bernoulli_(p, gen);
401
402 auto expected = torch::empty_like(actual);
403 native::templates::cpu::bernoulli_kernel(expected, p, check_generator<TestCPUGenerator>(gen));
404
405 ASSERT_TRUE(torch::allclose(actual, expected));
406}
407
408TEST_F(RNGTest, Bernoulli) {
409 const auto p = 0.42;
410 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
411
412 auto actual = at::bernoulli(torch::full({3,3}, p), gen);
413
414 auto expected = torch::empty_like(actual);
415 native::templates::cpu::bernoulli_kernel(expected, torch::full({3,3}, p), check_generator<TestCPUGenerator>(gen));
416
417 ASSERT_TRUE(torch::allclose(actual, expected));
418}
419
420TEST_F(RNGTest, Bernoulli_2) {
421 const auto p = 0.42;
422 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
423
424 auto actual = torch::full({3,3}, p).bernoulli(gen);
425
426 auto expected = torch::empty_like(actual);
427 native::templates::cpu::bernoulli_kernel(expected, torch::full({3,3}, p), check_generator<TestCPUGenerator>(gen));
428
429 ASSERT_TRUE(torch::allclose(actual, expected));
430}
431
432TEST_F(RNGTest, Bernoulli_p) {
433 const auto p = 0.42;
434 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
435
436 auto actual = at::bernoulli(torch::empty({3, 3}), p, gen);
437
438 auto expected = torch::empty_like(actual);
439 native::templates::cpu::bernoulli_kernel(expected, p, check_generator<TestCPUGenerator>(gen));
440
441 ASSERT_TRUE(torch::allclose(actual, expected));
442}
443
444TEST_F(RNGTest, Bernoulli_p_2) {
445 const auto p = 0.42;
446 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
447
448 auto actual = torch::empty({3, 3}).bernoulli(p, gen);
449
450 auto expected = torch::empty_like(actual);
451 native::templates::cpu::bernoulli_kernel(expected, p, check_generator<TestCPUGenerator>(gen));
452
453 ASSERT_TRUE(torch::allclose(actual, expected));
454}
455
456TEST_F(RNGTest, Bernoulli_out) {
457 const auto p = 0.42;
458 auto gen = at::make_generator<TestCPUGenerator>(MAGIC_NUMBER);
459
460 auto actual = torch::empty({3, 3});
461 at::bernoulli_out(actual, torch::full({3,3}, p), gen);
462
463 auto expected = torch::empty_like(actual);
464 native::templates::cpu::bernoulli_kernel(expected, torch::full({3,3}, p), check_generator<TestCPUGenerator>(gen));
465
466 ASSERT_TRUE(torch::allclose(actual, expected));
467}
468}
469#endif // ATEN_CPU_STATIC_DISPATCH
470