#include <gtest/gtest.h>

#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/debug_util.h>
#include <torch/csrc/lazy/core/dynamic_ir.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/generated/LazyIr.h>
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <memory>
#include <type_traits>
15 | |
16 | namespace torch { |
17 | namespace lazy { |
18 | |
19 | class TestLeafNode : public Node { |
20 | public: |
21 | static OpKind ClassOpKind() { |
22 | return OpKind(); |
23 | } |
24 | |
25 | explicit TestLeafNode(size_t param) |
26 | : Node(ClassOpKind(), /* num_outputs */ 1), |
27 | hash_(Hash(param)), |
28 | param_(param) {} |
29 | ~TestLeafNode() override = default; |
30 | |
31 | const std::vector<Output>& operands() const override { |
32 | TORCH_INTERNAL_ASSERT(false, "Can't access operands of leaf node" ); |
33 | } |
34 | |
35 | const Output& operand(size_t i) const override { |
36 | TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of leaf node" ); |
37 | } |
38 | |
39 | hash_t hash() const override { |
40 | return hash_; |
41 | } |
42 | hash_t shapeHash() const override { |
43 | return hash_; |
44 | } |
45 | |
46 | private: |
47 | hash_t hash_; |
48 | size_t param_; |
49 | }; |
50 | |
51 | TEST(IrTest, BasicTest) { |
52 | NodePtr node1 = MakeNode<TestLeafNode>(1); |
53 | NodePtr node2 = MakeNode<TestLeafNode>(2); |
54 | EXPECT_NE(node1->hash(), node2->hash()); |
55 | |
56 | EXPECT_EQ(node1->num_outputs(), 1); |
57 | |
58 | const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get()); |
59 | EXPECT_TRUE(leafptr != nullptr); |
60 | } |
61 | |
62 | TEST(IrTest, MetaDataTest) { |
63 | bool restore_FLAGS_torch_lazy_ir_debug = FLAGS_torch_lazy_ir_debug; |
64 | FLAGS_torch_lazy_ir_debug = false; |
65 | NodePtr node = MakeNode<TestLeafNode>(1); |
66 | auto metaWithoutDebug = node->metadata(); |
67 | EXPECT_EQ(metaWithoutDebug.scope.size(), 0); |
68 | EXPECT_EQ(metaWithoutDebug.frame_info.size(), 0); |
69 | |
70 | FLAGS_torch_lazy_ir_debug = true; |
71 | node = MakeNode<TestLeafNode>(1); |
72 | auto metaWithEmptyDebug = node->metadata(); |
73 | EXPECT_EQ(metaWithEmptyDebug.scope.size(), 0); |
74 | EXPECT_EQ(metaWithEmptyDebug.frame_info.size(), 1); |
75 | |
76 | { |
77 | ScopePusher scope("TestScope" ); |
78 | node = MakeNode<TestLeafNode>(1); |
79 | auto metaWithScope = node->metadata(); |
80 | EXPECT_EQ(metaWithScope.scope, "TestScope.1" ); |
81 | EXPECT_EQ(metaWithScope.frame_info.size(), 1); |
82 | } |
83 | |
84 | SourceLocation dummySourceLocation; |
85 | dummySourceLocation.file = "file" ; |
86 | dummySourceLocation.function = "function" ; |
87 | dummySourceLocation.line = 10; |
88 | GetPythonFramesFunction() = [&]() -> std::vector<SourceLocation> { |
89 | return {dummySourceLocation}; |
90 | }; |
91 | node = MakeNode<TestLeafNode>(1); |
92 | auto metaWithSourceLoc = node->metadata(); |
93 | EXPECT_EQ(metaWithSourceLoc.scope.size(), 0); |
94 | EXPECT_EQ(metaWithSourceLoc.frame_info.size(), 1); |
95 | EXPECT_EQ(metaWithSourceLoc.frame_info[0].file, "file" ); |
96 | EXPECT_EQ(metaWithSourceLoc.frame_info[0].function, "function" ); |
97 | EXPECT_EQ(metaWithSourceLoc.frame_info[0].line, 10); |
98 | FLAGS_torch_lazy_ir_debug = restore_FLAGS_torch_lazy_ir_debug; |
99 | } |
100 | |
101 | TEST(IrTest, TsNodeTest) { |
102 | NodePtr node1 = MakeNode<TsNode>( |
103 | OpKind(at::aten::view), |
104 | Shape(), |
105 | /*num_outputs*/ 1, |
106 | /*hash_seed*/ kHashSeed); |
107 | NodePtr node2 = MakeNode<TsNode>( |
108 | OpKind(at::aten::view), |
109 | Shape(), |
110 | /*num_outputs*/ 1, |
111 | /*hash_seed*/ kHashSeed); |
112 | EXPECT_EQ(node1->hash(), node2->hash()); |
113 | |
114 | EXPECT_EQ(node1->num_outputs(), 1); |
115 | |
116 | const TsNode* leafptr = dynamic_cast<const TsNode*>(node1.get()); |
117 | EXPECT_TRUE(leafptr != nullptr); |
118 | } |
119 | |
120 | TEST(IrTest, DimensionNodeTest) { |
121 | const size_t DIM0 = 5; |
122 | const size_t DIM1 = 8; |
123 | NodePtr node1 = MakeNode<TsNode>( |
124 | OpKind(at::aten::view), |
125 | Shape(c10::kFloat, {DIM0, DIM1}), |
126 | /*num_outputs*/ 1, |
127 | /*hash_seed*/ kHashSeed); |
128 | |
129 | auto size0 = |
130 | std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 0)); |
131 | auto size1 = |
132 | std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 1)); |
133 | |
134 | ASSERT_EQ(DIM0, size0->getStaticValue()); |
135 | ASSERT_EQ(DIM1, size1->getStaticValue()); |
136 | |
137 | NodePtr size0_np = size0; |
138 | auto size0_dn = std::dynamic_pointer_cast<DimensionNode>(size0_np); |
139 | ASSERT_EQ(DIM0, size0_dn->getStaticValue()); |
140 | |
141 | auto add_dim = std::dynamic_pointer_cast<SizeAdd>( |
142 | MakeNode<SizeAdd>(Value{size0}, Value{size1})); |
143 | ASSERT_EQ(DIM0 + DIM1, add_dim->getStaticValue()); |
144 | |
145 | auto mul_dim = std::dynamic_pointer_cast<SizeMul>( |
146 | MakeNode<SizeMul>(Value{size0}, Value{size1})); |
147 | ASSERT_EQ(DIM0 * DIM1, mul_dim->getStaticValue()); |
148 | } |
149 | |
150 | TEST(IrTest, DimensionIsDynamicTest) { |
151 | const size_t DIM0 = 5; |
152 | const size_t DIM1 = 8; |
153 | const auto shape = Shape(c10::kFloat, {DIM0, DIM1}); |
154 | NodePtr node1 = MakeNode<TsNode>( |
155 | OpKind(at::aten::view), |
156 | shape.with_symbolic_dims(std::vector<bool>{true, false}), |
157 | /*num_outputs*/ 1, |
158 | /*hash_seed*/ kHashSeed); |
159 | |
160 | auto size0 = |
161 | std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 0)); |
162 | auto size1 = |
163 | std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 1)); |
164 | |
165 | ASSERT_EQ(true, size0->isSymbolic()); |
166 | ASSERT_EQ(false, size1->isSymbolic()); |
167 | |
168 | auto add_dim = std::dynamic_pointer_cast<SizeAdd>( |
169 | MakeNode<SizeAdd>(Value{size0}, Value{size1})); |
170 | ASSERT_EQ(true, add_dim->isSymbolic()); |
171 | |
172 | add_dim = std::dynamic_pointer_cast<SizeAdd>( |
173 | MakeNode<SizeAdd>(Value{size1}, Value{size1})); |
174 | ASSERT_EQ(false, add_dim->isSymbolic()); |
175 | |
176 | auto mul_dim = std::dynamic_pointer_cast<SizeMul>( |
177 | MakeNode<SizeMul>(Value{size0}, Value{size0})); |
178 | ASSERT_EQ(true, mul_dim->isSymbolic()); |
179 | } |
180 | |
181 | } // namespace lazy |
182 | } // namespace torch |
183 | |