1#include <gtest/gtest.h>
2
3#include <sstream>
4
5#include <torch/csrc/lazy/core/shape.h>
6
7namespace torch {
8namespace lazy {
9
10TEST(ShapeTest, Basic1) {
11 auto shape = Shape();
12
13 EXPECT_STREQ(shape.to_string().c_str(), "UNKNOWN_SCALAR[]");
14 EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Undefined);
15 EXPECT_EQ(shape.dim(), 0);
16 EXPECT_TRUE(shape.sizes().empty());
17 EXPECT_THROW(shape.size(0), std::out_of_range);
18}
19
20TEST(ShapeTest, Basic2) {
21 auto shape = Shape(c10::ScalarType::Float, {1, 2, 3});
22
23 EXPECT_EQ(shape.numel(), 6);
24 EXPECT_STREQ(shape.to_string().c_str(), "Float[1,2,3]");
25 EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Float);
26 EXPECT_EQ(shape.dim(), 3);
27 EXPECT_EQ(shape.sizes().size(), 3);
28 for (int64_t i = 0; i < shape.dim(); i++) {
29 EXPECT_EQ(shape.sizes()[i], i + 1);
30 EXPECT_EQ(shape.size(i), i + 1);
31 }
32}
33
34TEST(ShapeTest, Basic3) {
35 auto shape = Shape(c10::ScalarType::Float, {});
36
37 EXPECT_STREQ(shape.to_string().c_str(), "Float[]");
38 EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Float);
39 EXPECT_EQ(shape.dim(), 0);
40 // this is surprising, but it's in line with how 0-D tensors behave
41 EXPECT_EQ(shape.numel(), 1);
42 EXPECT_TRUE(shape.sizes().empty());
43 EXPECT_THROW(shape.size(0), std::out_of_range);
44}
45
46TEST(ShapeTest, SetScalarType) {
47 auto shape = Shape();
48
49 shape.set_scalar_type(c10::ScalarType::Long);
50 EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Long);
51}
52
53TEST(ShapeTest, SetSize) {
54 auto shape1 = Shape();
55 EXPECT_THROW(shape1.set_size(0, 0), std::out_of_range);
56
57 auto shape2 = Shape(c10::ScalarType::Float, {1, 2, 3});
58 shape2.set_size(0, 3);
59 EXPECT_EQ(shape2.sizes()[0], 3);
60 EXPECT_EQ(shape2.size(0), 3);
61}
62
63TEST(ShapeTest, Equal) {
64 auto shape1 = Shape(c10::ScalarType::Float, {});
65 auto shape2 = Shape(c10::ScalarType::Float, {1, 2, 3});
66 auto shape3 = Shape(c10::ScalarType::Long, {1, 2, 3});
67 auto shape4 = Shape(c10::ScalarType::Float, {1, 2, 3});
68
69 EXPECT_FALSE(shape1 == shape2);
70 EXPECT_FALSE(shape2 == shape3);
71 EXPECT_FALSE(shape1 == shape3);
72 EXPECT_TRUE(shape2 == shape2);
73}
74
75TEST(ShapeTest, Ostream) {
76 auto shape = Shape();
77 std::stringstream ss;
78 ss << shape;
79
80 EXPECT_EQ(shape.to_string(), ss.str());
81}
82
83} // namespace lazy
84} // namespace torch
85