1#include <gtest/gtest.h>
2
3#include <c10/core/ScalarType.h>
4#include <c10/util/Exception.h>
5#include <torch/csrc/lazy/core/config.h>
6#include <torch/csrc/lazy/core/debug_util.h>
7#include <torch/csrc/lazy/core/dynamic_ir.h>
8#include <torch/csrc/lazy/core/ir.h>
9#include <torch/csrc/lazy/core/ir_builder.h>
10#include <torch/csrc/lazy/core/ir_metadata.h>
11#include <torch/csrc/lazy/generated/LazyIr.h>
12#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
13#include <torch/csrc/lazy/ts_backend/ts_node.h>
14#include <memory>
15
16namespace torch {
17namespace lazy {
18
19class TestLeafNode : public Node {
20 public:
21 static OpKind ClassOpKind() {
22 return OpKind();
23 }
24
25 explicit TestLeafNode(size_t param)
26 : Node(ClassOpKind(), /* num_outputs */ 1),
27 hash_(Hash(param)),
28 param_(param) {}
29 ~TestLeafNode() override = default;
30
31 const std::vector<Output>& operands() const override {
32 TORCH_INTERNAL_ASSERT(false, "Can't access operands of leaf node");
33 }
34
35 const Output& operand(size_t i) const override {
36 TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of leaf node");
37 }
38
39 hash_t hash() const override {
40 return hash_;
41 }
42 hash_t shapeHash() const override {
43 return hash_;
44 }
45
46 private:
47 hash_t hash_;
48 size_t param_;
49};
50
51TEST(IrTest, BasicTest) {
52 NodePtr node1 = MakeNode<TestLeafNode>(1);
53 NodePtr node2 = MakeNode<TestLeafNode>(2);
54 EXPECT_NE(node1->hash(), node2->hash());
55
56 EXPECT_EQ(node1->num_outputs(), 1);
57
58 const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get());
59 EXPECT_TRUE(leafptr != nullptr);
60}
61
62TEST(IrTest, MetaDataTest) {
63 bool restore_FLAGS_torch_lazy_ir_debug = FLAGS_torch_lazy_ir_debug;
64 FLAGS_torch_lazy_ir_debug = false;
65 NodePtr node = MakeNode<TestLeafNode>(1);
66 auto metaWithoutDebug = node->metadata();
67 EXPECT_EQ(metaWithoutDebug.scope.size(), 0);
68 EXPECT_EQ(metaWithoutDebug.frame_info.size(), 0);
69
70 FLAGS_torch_lazy_ir_debug = true;
71 node = MakeNode<TestLeafNode>(1);
72 auto metaWithEmptyDebug = node->metadata();
73 EXPECT_EQ(metaWithEmptyDebug.scope.size(), 0);
74 EXPECT_EQ(metaWithEmptyDebug.frame_info.size(), 1);
75
76 {
77 ScopePusher scope("TestScope");
78 node = MakeNode<TestLeafNode>(1);
79 auto metaWithScope = node->metadata();
80 EXPECT_EQ(metaWithScope.scope, "TestScope.1");
81 EXPECT_EQ(metaWithScope.frame_info.size(), 1);
82 }
83
84 SourceLocation dummySourceLocation;
85 dummySourceLocation.file = "file";
86 dummySourceLocation.function = "function";
87 dummySourceLocation.line = 10;
88 GetPythonFramesFunction() = [&]() -> std::vector<SourceLocation> {
89 return {dummySourceLocation};
90 };
91 node = MakeNode<TestLeafNode>(1);
92 auto metaWithSourceLoc = node->metadata();
93 EXPECT_EQ(metaWithSourceLoc.scope.size(), 0);
94 EXPECT_EQ(metaWithSourceLoc.frame_info.size(), 1);
95 EXPECT_EQ(metaWithSourceLoc.frame_info[0].file, "file");
96 EXPECT_EQ(metaWithSourceLoc.frame_info[0].function, "function");
97 EXPECT_EQ(metaWithSourceLoc.frame_info[0].line, 10);
98 FLAGS_torch_lazy_ir_debug = restore_FLAGS_torch_lazy_ir_debug;
99}
100
101TEST(IrTest, TsNodeTest) {
102 NodePtr node1 = MakeNode<TsNode>(
103 OpKind(at::aten::view),
104 Shape(),
105 /*num_outputs*/ 1,
106 /*hash_seed*/ kHashSeed);
107 NodePtr node2 = MakeNode<TsNode>(
108 OpKind(at::aten::view),
109 Shape(),
110 /*num_outputs*/ 1,
111 /*hash_seed*/ kHashSeed);
112 EXPECT_EQ(node1->hash(), node2->hash());
113
114 EXPECT_EQ(node1->num_outputs(), 1);
115
116 const TsNode* leafptr = dynamic_cast<const TsNode*>(node1.get());
117 EXPECT_TRUE(leafptr != nullptr);
118}
119
120TEST(IrTest, DimensionNodeTest) {
121 const size_t DIM0 = 5;
122 const size_t DIM1 = 8;
123 NodePtr node1 = MakeNode<TsNode>(
124 OpKind(at::aten::view),
125 Shape(c10::kFloat, {DIM0, DIM1}),
126 /*num_outputs*/ 1,
127 /*hash_seed*/ kHashSeed);
128
129 auto size0 =
130 std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 0));
131 auto size1 =
132 std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 1));
133
134 ASSERT_EQ(DIM0, size0->getStaticValue());
135 ASSERT_EQ(DIM1, size1->getStaticValue());
136
137 NodePtr size0_np = size0;
138 auto size0_dn = std::dynamic_pointer_cast<DimensionNode>(size0_np);
139 ASSERT_EQ(DIM0, size0_dn->getStaticValue());
140
141 auto add_dim = std::dynamic_pointer_cast<SizeAdd>(
142 MakeNode<SizeAdd>(Value{size0}, Value{size1}));
143 ASSERT_EQ(DIM0 + DIM1, add_dim->getStaticValue());
144
145 auto mul_dim = std::dynamic_pointer_cast<SizeMul>(
146 MakeNode<SizeMul>(Value{size0}, Value{size1}));
147 ASSERT_EQ(DIM0 * DIM1, mul_dim->getStaticValue());
148}
149
150TEST(IrTest, DimensionIsDynamicTest) {
151 const size_t DIM0 = 5;
152 const size_t DIM1 = 8;
153 const auto shape = Shape(c10::kFloat, {DIM0, DIM1});
154 NodePtr node1 = MakeNode<TsNode>(
155 OpKind(at::aten::view),
156 shape.with_symbolic_dims(std::vector<bool>{true, false}),
157 /*num_outputs*/ 1,
158 /*hash_seed*/ kHashSeed);
159
160 auto size0 =
161 std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 0));
162 auto size1 =
163 std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 1));
164
165 ASSERT_EQ(true, size0->isSymbolic());
166 ASSERT_EQ(false, size1->isSymbolic());
167
168 auto add_dim = std::dynamic_pointer_cast<SizeAdd>(
169 MakeNode<SizeAdd>(Value{size0}, Value{size1}));
170 ASSERT_EQ(true, add_dim->isSymbolic());
171
172 add_dim = std::dynamic_pointer_cast<SizeAdd>(
173 MakeNode<SizeAdd>(Value{size1}, Value{size1}));
174 ASSERT_EQ(false, add_dim->isSymbolic());
175
176 auto mul_dim = std::dynamic_pointer_cast<SizeMul>(
177 MakeNode<SizeMul>(Value{size0}, Value{size0}));
178 ASSERT_EQ(true, mul_dim->isSymbolic());
179}
180
181} // namespace lazy
182} // namespace torch
183