1 | #include <gtest/gtest.h> |
2 | |
3 | #include <sstream> |
4 | |
5 | #include <torch/csrc/lazy/core/shape.h> |
6 | |
7 | namespace torch { |
8 | namespace lazy { |
9 | |
10 | TEST(ShapeTest, Basic1) { |
11 | auto shape = Shape(); |
12 | |
13 | EXPECT_STREQ(shape.to_string().c_str(), "UNKNOWN_SCALAR[]" ); |
14 | EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Undefined); |
15 | EXPECT_EQ(shape.dim(), 0); |
16 | EXPECT_TRUE(shape.sizes().empty()); |
17 | EXPECT_THROW(shape.size(0), std::out_of_range); |
18 | } |
19 | |
20 | TEST(ShapeTest, Basic2) { |
21 | auto shape = Shape(c10::ScalarType::Float, {1, 2, 3}); |
22 | |
23 | EXPECT_EQ(shape.numel(), 6); |
24 | EXPECT_STREQ(shape.to_string().c_str(), "Float[1,2,3]" ); |
25 | EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Float); |
26 | EXPECT_EQ(shape.dim(), 3); |
27 | EXPECT_EQ(shape.sizes().size(), 3); |
28 | for (int64_t i = 0; i < shape.dim(); i++) { |
29 | EXPECT_EQ(shape.sizes()[i], i + 1); |
30 | EXPECT_EQ(shape.size(i), i + 1); |
31 | } |
32 | } |
33 | |
34 | TEST(ShapeTest, Basic3) { |
35 | auto shape = Shape(c10::ScalarType::Float, {}); |
36 | |
37 | EXPECT_STREQ(shape.to_string().c_str(), "Float[]" ); |
38 | EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Float); |
39 | EXPECT_EQ(shape.dim(), 0); |
40 | // this is surprising, but it's in line with how 0-D tensors behave |
41 | EXPECT_EQ(shape.numel(), 1); |
42 | EXPECT_TRUE(shape.sizes().empty()); |
43 | EXPECT_THROW(shape.size(0), std::out_of_range); |
44 | } |
45 | |
46 | TEST(ShapeTest, SetScalarType) { |
47 | auto shape = Shape(); |
48 | |
49 | shape.set_scalar_type(c10::ScalarType::Long); |
50 | EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Long); |
51 | } |
52 | |
53 | TEST(ShapeTest, SetSize) { |
54 | auto shape1 = Shape(); |
55 | EXPECT_THROW(shape1.set_size(0, 0), std::out_of_range); |
56 | |
57 | auto shape2 = Shape(c10::ScalarType::Float, {1, 2, 3}); |
58 | shape2.set_size(0, 3); |
59 | EXPECT_EQ(shape2.sizes()[0], 3); |
60 | EXPECT_EQ(shape2.size(0), 3); |
61 | } |
62 | |
63 | TEST(ShapeTest, Equal) { |
64 | auto shape1 = Shape(c10::ScalarType::Float, {}); |
65 | auto shape2 = Shape(c10::ScalarType::Float, {1, 2, 3}); |
66 | auto shape3 = Shape(c10::ScalarType::Long, {1, 2, 3}); |
67 | auto shape4 = Shape(c10::ScalarType::Float, {1, 2, 3}); |
68 | |
69 | EXPECT_FALSE(shape1 == shape2); |
70 | EXPECT_FALSE(shape2 == shape3); |
71 | EXPECT_FALSE(shape1 == shape3); |
72 | EXPECT_TRUE(shape2 == shape2); |
73 | } |
74 | |
75 | TEST(ShapeTest, Ostream) { |
76 | auto shape = Shape(); |
77 | std::stringstream ss; |
78 | ss << shape; |
79 | |
80 | EXPECT_EQ(shape.to_string(), ss.str()); |
81 | } |
82 | |
83 | } // namespace lazy |
84 | } // namespace torch |
85 | |