1#include <gtest/gtest.h>
2
3#include <test/cpp/jit/test_utils.h>
4#include <torch/csrc/jit/ir/irparser.h>
5
6namespace torch {
7namespace jit {
8
9TEST(IRTest, Attributes) {
10 Graph g;
11 auto one = attr::alpha;
12 auto two = attr::device;
13 auto three = attr::end;
14 auto four = attr::perm;
15 Node* n = g.create(Symbol::fromQualString("foo::bar"));
16 Node& attr = *n;
17 attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what");
18 ASSERT_EQ(attr.f(one), 3.4);
19 ASSERT_EQ(attr.s(three), "what");
20 ASSERT_EQ(attr.i(two), 5);
21 attr.s_(one, "no");
22 ASSERT_EQ(attr.s(one), "no");
23 ASSERT_TRUE(attr.hasAttribute(three));
24 ASSERT_TRUE(!attr.hasAttribute(four));
25 attr.ss_(two, {"hi", "now"});
26 ASSERT_EQ(attr.ss(two).at(1), "now");
27
28 Node* n2 = g.create(Symbol::fromQualString("foo::baz"));
29 Node& attr2 = *n2;
30 attr2.copyAttributes(attr);
31 ASSERT_EQ(attr2.s(one), "no");
32 attr2.f_(one, 5);
33 ASSERT_EQ(attr.s(one), "no");
34 ASSERT_EQ(attr2.f(one), 5);
35}
36
37TEST(IRTest, Blocks) {
38 auto g = std::make_shared<Graph>();
39 const auto graph_string = R"IR(
40 graph(%a : Tensor,
41 %b : Tensor,
42 %c : Tensor):
43 %2 : int = prim::Constant[value=1]()
44 %3 : Tensor = aten::add(%a, %b, %2)
45 %5 : Tensor = prim::If(%c)
46 block0():
47 %6 : int = prim::Constant[value=1]()
48 %7 : Tensor = aten::add(%3, %3, %6)
49 -> (%7)
50 block1():
51 %8 : int = prim::Constant[value=1]()
52 %9 : Tensor = aten::add(%b, %3, %8)
53 %10 : int = prim::Constant[value=1]()
54 %11 : Tensor = aten::add(%9, %3, %10)
55 -> (%11)
56 %12 : int = prim::Constant[value=1]()
57 %13 : Tensor = aten::add(%5, %3, %12)
58 return (%13))IR";
59 torch::jit::parseIR(graph_string, g.get());
60
61 g->lint();
62 testing::FileCheck()
63 .check("add")
64 ->check("prim::If")
65 ->check("block0")
66 ->check("aten::add")
67 ->check("block1")
68 ->check_count("aten::add", 3)
69 ->run(*g);
70
71 // Removes block0 of the conditional
72 for (auto* node : g->block()->nodes()) {
73 if (node->kind() == prim::If) {
74 node->eraseBlock(0);
75 break;
76 }
77 }
78
79 testing::FileCheck()
80 .check("add")
81 ->check("prim::If")
82 ->check("block0")
83 ->check_not("block")
84 ->run(*g);
85 g->lint();
86 // test recursive copy of blocks works
87 auto g2 = g->copy();
88 testing::FileCheck()
89 .check("add")
90 ->check("prim::If")
91 ->check("block0")
92 ->check_not("block")
93 ->run(*g2);
94}
95
96TEST(IRTest, CommonAncestor) {
97 std::string input_str = R"(
98graph(%x : Tensor,
99 %a.1 : bool,
100 %b.1 : bool,
101 %c.1 : bool):
102 %4 : int = prim::If(%a.1)
103 block0():
104 %5 : int = prim::If(%b.1)
105 block0():
106 %6 : int = prim::Constant[value=2]()
107 -> (%6)
108 block1():
109 %7 : int = prim::Constant[value=3]()
110 -> (%7)
111 -> (%5)
112 block1():
113 %8 : int = prim::If(%c.1)
114 block0():
115 %9 : int = prim::Constant[value=4]()
116 -> (%9)
117 block1():
118 %10 : int = prim::Constant[value=5]()
119 -> (%10)
120 -> (%8)
121 return (%4)
122)";
123
124 torch::jit::Graph g;
125 std::unordered_map<std::string, torch::jit::Value*> name_to_value;
126 torch::jit::parseIR(input_str, &g, name_to_value);
127
128 std::vector<std::string> value_names{"6", "7", "9", "10"};
129 std::unordered_set<std::string> value_names_set(
130 value_names.begin(), value_names.end());
131
132 /* clang-format off */
133 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
134 int ref_blocks_from_graph[4][4] = {
135 /* (6, 6), (6, 7), (6, 9), (6, 10) */
136 { 2, 1, 0, 0 },
137 /* (7, 6), (7, 7), (7, 9), (7, 10) */
138 { 1, 2, 0, 0 },
139 /* (9, 6), (9, 7), (9, 9), (9, 10) */
140 { 0, 0, 2, 1, },
141 /* (10, 6),(10, 7),(10, 9),(10, 10) */
142 { 0, 0, 1, 2 }
143 };
144 /* clang-format on */
145
146 for (size_t i = 0; i < value_names.size(); ++i) {
147 Value* i_val = name_to_value[value_names[i]];
148 for (size_t j = 0; j < value_names.size(); ++j) {
149 Value* j_val = name_to_value[value_names[j]];
150 Block* common_ancestor =
151 i_val->node()->findCommonAncestorBlockWith(j_val->node());
152 int blocks_from_graph_block =
153 common_ancestor->param_node()->blocksFromGraphBlock();
154 ASSERT_EQ(blocks_from_graph_block, ref_blocks_from_graph[i][j]);
155 }
156 }
157}
158
159TEST(IRTest, OperatorMap) {
160 OperatorMap<int> op_map;
161 const char* literal1 =
162 "aten::dropout(Tensor input, float p, bool train) -> Tensor";
163 const char* literal2 =
164 "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor";
165 const char* literal3 =
166 "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor";
167 const char* literal4 =
168 "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor";
169 const char* literal5 =
170 "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor";
171 const char* literal6 =
172 "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor";
173 std::shared_ptr<Operator> op1 = getOperatorForLiteral(literal1);
174 std::shared_ptr<Operator> op2 = getOperatorForLiteral(literal2);
175 std::shared_ptr<Operator> op3 = getOperatorForLiteral(literal3);
176 std::shared_ptr<Operator> op4 = getOperatorForLiteral(literal4);
177 std::shared_ptr<Operator> op5 = getOperatorForLiteral(literal5);
178 std::shared_ptr<Operator> op6 = getOperatorForLiteral(literal6);
179 op_map.insert(op1, 1);
180 op_map.insert({{op2, 2}, {op3, 3}});
181 op_map.insert({{op4, 4}, {op5, 5}});
182 op_map.insert(op6, 6);
183 ASSERT_TRUE(op_map.contains(*op1));
184 ASSERT_TRUE(op_map.contains(*op2));
185 ASSERT_TRUE(op_map.contains(*op3));
186 ASSERT_TRUE(op_map.contains(*op4));
187 ASSERT_TRUE(op_map.contains(*op5));
188 ASSERT_TRUE(op_map.contains(*op6));
189 op_map.erase(op6);
190 op_map.erase(op3);
191 op_map.erase(op1);
192 ASSERT_FALSE(op_map.contains(*op1));
193 ASSERT_FALSE(op_map.contains(*op3));
194 ASSERT_FALSE(op_map.contains(*op6));
195 op_map.insert(op1, 1);
196 ASSERT_TRUE(op_map.contains(*op1));
197 c10::optional<int> o1 = op_map.find(*op1);
198 ASSERT_TRUE(o1.has_value());
199 c10::optional<int> o2 = op_map.find(*op2);
200 ASSERT_TRUE(o2.has_value());
201 c10::optional<int> o3 = op_map.find(*op3);
202 ASSERT_FALSE(o3.has_value());
203 c10::optional<int> o4 = op_map.find(*op4);
204 ASSERT_TRUE(o4.has_value());
205 c10::optional<int> o5 = op_map.find(*op5);
206 ASSERT_TRUE(o5.has_value());
207 c10::optional<int> o6 = op_map.find(*op6);
208 ASSERT_FALSE(o6.has_value());
209}
210
211} // namespace jit
212} // namespace torch
213