1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Base/IO.h" |
18 | #include "glow/Support/Random.h" |
19 | |
20 | #include "llvm/ADT/SmallString.h" |
21 | #include "llvm/Support/FileSystem.h" |
22 | |
23 | #include "gtest/gtest.h" |
24 | |
25 | using namespace glow; |
26 | |
27 | // Test that nextRandInt generates every number in the closed interval [lb, ub]. |
28 | // Use enough trials that the probability of random failure is < 1.0e-9. |
29 | TEST(Utils, PRNGBasics) { |
30 | PseudoRNG PRNG; |
31 | constexpr int lb = -3; |
32 | constexpr int ub = 3; |
33 | constexpr int trials = 200; |
34 | |
35 | for (int i = lb; i <= ub; i++) { |
36 | int j = 0; |
37 | for (; j < trials; j++) { |
38 | if (PRNG.nextRandInt(lb, ub) == i) { |
39 | break; |
40 | } |
41 | } |
42 | EXPECT_LT(j, trials); |
43 | } |
44 | } |
45 | |
46 | // Test that two default-constructed PseudoRNG objects do in fact generate |
47 | // identical sequences. |
48 | TEST(Utils, deterministicPRNG) { |
49 | PseudoRNG genA, genB; |
50 | std::uniform_int_distribution<int> dist(0, 100000); |
51 | |
52 | for (unsigned i = 0; i != 100; i++) { |
53 | EXPECT_EQ(dist(genA), dist(genB)); |
54 | } |
55 | } |
56 | |
57 | TEST(Utils, readWriteTensor) { |
58 | llvm::SmallString<64> path; |
59 | llvm::sys::fs::createTemporaryFile("tensor" , "bin" , path); |
60 | Tensor output(ElemKind::FloatTy, {2, 1, 4}); |
61 | output.getHandle() = {1, 2, 3, 4, 5, 6, 7, 8}; |
62 | writeToFile(output, path); |
63 | Tensor input; |
64 | readFromFile(input, path); |
65 | llvm::sys::fs::remove(path); |
66 | EXPECT_TRUE(output.isEqual(input)); |
67 | } |
68 | |