1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | #include <torch/csrc/jit/ir/irparser.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
9 | TEST(IRTest, Attributes) { |
10 | Graph g; |
11 | auto one = attr::alpha; |
12 | auto two = attr::device; |
13 | auto three = attr::end; |
14 | auto four = attr::perm; |
15 | Node* n = g.create(Symbol::fromQualString("foo::bar" )); |
16 | Node& attr = *n; |
17 | attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what" ); |
18 | ASSERT_EQ(attr.f(one), 3.4); |
19 | ASSERT_EQ(attr.s(three), "what" ); |
20 | ASSERT_EQ(attr.i(two), 5); |
21 | attr.s_(one, "no" ); |
22 | ASSERT_EQ(attr.s(one), "no" ); |
23 | ASSERT_TRUE(attr.hasAttribute(three)); |
24 | ASSERT_TRUE(!attr.hasAttribute(four)); |
25 | attr.ss_(two, {"hi" , "now" }); |
26 | ASSERT_EQ(attr.ss(two).at(1), "now" ); |
27 | |
28 | Node* n2 = g.create(Symbol::fromQualString("foo::baz" )); |
29 | Node& attr2 = *n2; |
30 | attr2.copyAttributes(attr); |
31 | ASSERT_EQ(attr2.s(one), "no" ); |
32 | attr2.f_(one, 5); |
33 | ASSERT_EQ(attr.s(one), "no" ); |
34 | ASSERT_EQ(attr2.f(one), 5); |
35 | } |
36 | |
37 | TEST(IRTest, Blocks) { |
38 | auto g = std::make_shared<Graph>(); |
39 | const auto graph_string = R"IR( |
40 | graph(%a : Tensor, |
41 | %b : Tensor, |
42 | %c : Tensor): |
43 | %2 : int = prim::Constant[value=1]() |
44 | %3 : Tensor = aten::add(%a, %b, %2) |
45 | %5 : Tensor = prim::If(%c) |
46 | block0(): |
47 | %6 : int = prim::Constant[value=1]() |
48 | %7 : Tensor = aten::add(%3, %3, %6) |
49 | -> (%7) |
50 | block1(): |
51 | %8 : int = prim::Constant[value=1]() |
52 | %9 : Tensor = aten::add(%b, %3, %8) |
53 | %10 : int = prim::Constant[value=1]() |
54 | %11 : Tensor = aten::add(%9, %3, %10) |
55 | -> (%11) |
56 | %12 : int = prim::Constant[value=1]() |
57 | %13 : Tensor = aten::add(%5, %3, %12) |
58 | return (%13))IR" ; |
59 | torch::jit::parseIR(graph_string, g.get()); |
60 | |
61 | g->lint(); |
62 | testing::FileCheck() |
63 | .check("add" ) |
64 | ->check("prim::If" ) |
65 | ->check("block0" ) |
66 | ->check("aten::add" ) |
67 | ->check("block1" ) |
68 | ->check_count("aten::add" , 3) |
69 | ->run(*g); |
70 | |
71 | // Removes block0 of the conditional |
72 | for (auto* node : g->block()->nodes()) { |
73 | if (node->kind() == prim::If) { |
74 | node->eraseBlock(0); |
75 | break; |
76 | } |
77 | } |
78 | |
79 | testing::FileCheck() |
80 | .check("add" ) |
81 | ->check("prim::If" ) |
82 | ->check("block0" ) |
83 | ->check_not("block" ) |
84 | ->run(*g); |
85 | g->lint(); |
86 | // test recursive copy of blocks works |
87 | auto g2 = g->copy(); |
88 | testing::FileCheck() |
89 | .check("add" ) |
90 | ->check("prim::If" ) |
91 | ->check("block0" ) |
92 | ->check_not("block" ) |
93 | ->run(*g2); |
94 | } |
95 | |
96 | TEST(IRTest, CommonAncestor) { |
97 | std::string input_str = R"( |
98 | graph(%x : Tensor, |
99 | %a.1 : bool, |
100 | %b.1 : bool, |
101 | %c.1 : bool): |
102 | %4 : int = prim::If(%a.1) |
103 | block0(): |
104 | %5 : int = prim::If(%b.1) |
105 | block0(): |
106 | %6 : int = prim::Constant[value=2]() |
107 | -> (%6) |
108 | block1(): |
109 | %7 : int = prim::Constant[value=3]() |
110 | -> (%7) |
111 | -> (%5) |
112 | block1(): |
113 | %8 : int = prim::If(%c.1) |
114 | block0(): |
115 | %9 : int = prim::Constant[value=4]() |
116 | -> (%9) |
117 | block1(): |
118 | %10 : int = prim::Constant[value=5]() |
119 | -> (%10) |
120 | -> (%8) |
121 | return (%4) |
122 | )" ; |
123 | |
124 | torch::jit::Graph g; |
125 | std::unordered_map<std::string, torch::jit::Value*> name_to_value; |
126 | torch::jit::parseIR(input_str, &g, name_to_value); |
127 | |
128 | std::vector<std::string> value_names{"6" , "7" , "9" , "10" }; |
129 | std::unordered_set<std::string> value_names_set( |
130 | value_names.begin(), value_names.end()); |
131 | |
132 | /* clang-format off */ |
133 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
134 | int ref_blocks_from_graph[4][4] = { |
135 | /* (6, 6), (6, 7), (6, 9), (6, 10) */ |
136 | { 2, 1, 0, 0 }, |
137 | /* (7, 6), (7, 7), (7, 9), (7, 10) */ |
138 | { 1, 2, 0, 0 }, |
139 | /* (9, 6), (9, 7), (9, 9), (9, 10) */ |
140 | { 0, 0, 2, 1, }, |
141 | /* (10, 6),(10, 7),(10, 9),(10, 10) */ |
142 | { 0, 0, 1, 2 } |
143 | }; |
144 | /* clang-format on */ |
145 | |
146 | for (size_t i = 0; i < value_names.size(); ++i) { |
147 | Value* i_val = name_to_value[value_names[i]]; |
148 | for (size_t j = 0; j < value_names.size(); ++j) { |
149 | Value* j_val = name_to_value[value_names[j]]; |
150 | Block* common_ancestor = |
151 | i_val->node()->findCommonAncestorBlockWith(j_val->node()); |
152 | int blocks_from_graph_block = |
153 | common_ancestor->param_node()->blocksFromGraphBlock(); |
154 | ASSERT_EQ(blocks_from_graph_block, ref_blocks_from_graph[i][j]); |
155 | } |
156 | } |
157 | } |
158 | |
159 | TEST(IRTest, OperatorMap) { |
160 | OperatorMap<int> op_map; |
161 | const char* literal1 = |
162 | "aten::dropout(Tensor input, float p, bool train) -> Tensor" ; |
163 | const char* literal2 = |
164 | "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor" ; |
165 | const char* literal3 = |
166 | "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor" ; |
167 | const char* literal4 = |
168 | "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor" ; |
169 | const char* literal5 = |
170 | "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor" ; |
171 | const char* literal6 = |
172 | "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor" ; |
173 | std::shared_ptr<Operator> op1 = getOperatorForLiteral(literal1); |
174 | std::shared_ptr<Operator> op2 = getOperatorForLiteral(literal2); |
175 | std::shared_ptr<Operator> op3 = getOperatorForLiteral(literal3); |
176 | std::shared_ptr<Operator> op4 = getOperatorForLiteral(literal4); |
177 | std::shared_ptr<Operator> op5 = getOperatorForLiteral(literal5); |
178 | std::shared_ptr<Operator> op6 = getOperatorForLiteral(literal6); |
179 | op_map.insert(op1, 1); |
180 | op_map.insert({{op2, 2}, {op3, 3}}); |
181 | op_map.insert({{op4, 4}, {op5, 5}}); |
182 | op_map.insert(op6, 6); |
183 | ASSERT_TRUE(op_map.contains(*op1)); |
184 | ASSERT_TRUE(op_map.contains(*op2)); |
185 | ASSERT_TRUE(op_map.contains(*op3)); |
186 | ASSERT_TRUE(op_map.contains(*op4)); |
187 | ASSERT_TRUE(op_map.contains(*op5)); |
188 | ASSERT_TRUE(op_map.contains(*op6)); |
189 | op_map.erase(op6); |
190 | op_map.erase(op3); |
191 | op_map.erase(op1); |
192 | ASSERT_FALSE(op_map.contains(*op1)); |
193 | ASSERT_FALSE(op_map.contains(*op3)); |
194 | ASSERT_FALSE(op_map.contains(*op6)); |
195 | op_map.insert(op1, 1); |
196 | ASSERT_TRUE(op_map.contains(*op1)); |
197 | c10::optional<int> o1 = op_map.find(*op1); |
198 | ASSERT_TRUE(o1.has_value()); |
199 | c10::optional<int> o2 = op_map.find(*op2); |
200 | ASSERT_TRUE(o2.has_value()); |
201 | c10::optional<int> o3 = op_map.find(*op3); |
202 | ASSERT_FALSE(o3.has_value()); |
203 | c10::optional<int> o4 = op_map.find(*op4); |
204 | ASSERT_TRUE(o4.has_value()); |
205 | c10::optional<int> o5 = op_map.find(*op5); |
206 | ASSERT_TRUE(o5.has_value()); |
207 | c10::optional<int> o6 = op_map.find(*op6); |
208 | ASSERT_FALSE(o6.has_value()); |
209 | } |
210 | |
211 | } // namespace jit |
212 | } // namespace torch |
213 | |