1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "glow/Base/IO.h"
#include "glow/Support/Random.h"

#include "llvm/ADT/SmallString.h"
#include "llvm/Support/FileSystem.h"

#include "gtest/gtest.h"

#include <random>
#include <system_error>
24
25using namespace glow;
26
27// Test that nextRandInt generates every number in the closed interval [lb, ub].
28// Use enough trials that the probability of random failure is < 1.0e-9.
29TEST(Utils, PRNGBasics) {
30 PseudoRNG PRNG;
31 constexpr int lb = -3;
32 constexpr int ub = 3;
33 constexpr int trials = 200;
34
35 for (int i = lb; i <= ub; i++) {
36 int j = 0;
37 for (; j < trials; j++) {
38 if (PRNG.nextRandInt(lb, ub) == i) {
39 break;
40 }
41 }
42 EXPECT_LT(j, trials);
43 }
44}
45
46// Test that two default-constructed PseudoRNG objects do in fact generate
47// identical sequences.
48TEST(Utils, deterministicPRNG) {
49 PseudoRNG genA, genB;
50 std::uniform_int_distribution<int> dist(0, 100000);
51
52 for (unsigned i = 0; i != 100; i++) {
53 EXPECT_EQ(dist(genA), dist(genB));
54 }
55}
56
57TEST(Utils, readWriteTensor) {
58 llvm::SmallString<64> path;
59 llvm::sys::fs::createTemporaryFile("tensor", "bin", path);
60 Tensor output(ElemKind::FloatTy, {2, 1, 4});
61 output.getHandle() = {1, 2, 3, 4, 5, 6, 7, 8};
62 writeToFile(output, path);
63 Tensor input;
64 readFromFile(input, path);
65 llvm::sys::fs::remove(path);
66 EXPECT_TRUE(output.isEqual(input));
67}
68