#include <gtest/gtest.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <cctype>
#include <cstdint>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
13 | /** \brief Parse IR from \p S, print the parsed graph and verify that the output |
14 | * string matches the original string. |
15 | * |
16 | * The function is sensitive to value naming and whitespace, so it should be |
17 | * used with care. Nevertheless, it helps to keep tests more compact. |
18 | */ |
19 | static void checkRoundtrip(const std::string& s) { |
20 | auto graph = std::make_shared<Graph>(); |
21 | parseIR(s, &*graph); |
22 | std::ostringstream ss; |
23 | ss << *graph; |
24 | std::string parsed = ss.str(); |
25 | |
26 | // Skip whitespace in the beginning of the input string. |
27 | int i = 0; |
28 | for (char c : s) { |
29 | if (!isspace(c)) { |
30 | break; |
31 | } |
32 | i++; |
33 | } |
34 | std::string original = s.substr(i, s.size()); |
35 | if (original != parsed) { |
36 | std::cerr << "Input:" << std::endl << original << std::endl; |
37 | std::cerr << "Parsed:" << std::endl << parsed << std::endl; |
38 | } |
39 | AT_ASSERT(original == parsed); |
40 | } |
41 | |
42 | TEST(IRParserTest, Basic) { |
43 | auto graph = std::make_shared<Graph>(); |
44 | std::unordered_map<std::string, Value*> vmap; |
45 | parseIR( |
46 | R"IR( |
47 | graph(%0 : Tensor, %1 : Tensor): |
48 | %2 : Tensor = foo::add(%0, %1) |
49 | %res, %3 = foo::mul(%0, %2) |
50 | %x, %y = foo::combine(%res, %2, %3) |
51 | return (%x, %y, %res))IR" , |
52 | &*graph, |
53 | vmap); |
54 | |
55 | AT_ASSERT(graph->inputs().size() == 2); |
56 | AT_ASSERT(graph->outputs().size() == 3); |
57 | Value* x = graph->outputs()[0]; |
58 | Value* y = graph->outputs()[1]; |
59 | Value* res = graph->outputs()[2]; |
60 | Value* t0 = graph->inputs()[0]; |
61 | Value* t1 = graph->inputs()[1]; |
62 | AT_ASSERT(vmap["x" ] == x); |
63 | AT_ASSERT(vmap["y" ] == y); |
64 | AT_ASSERT(vmap["res" ] == res); |
65 | AT_ASSERT(vmap["0" ] == t0); |
66 | AT_ASSERT(vmap["1" ] == t1); |
67 | AT_ASSERT(x->node() == y->node()); |
68 | Node* comb = x->node(); |
69 | Value* t2 = comb->inputs()[1]; |
70 | Value* t3 = comb->inputs()[2]; |
71 | AT_ASSERT(vmap["2" ] == t2); |
72 | AT_ASSERT(vmap["3" ] == t3); |
73 | AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine" )); |
74 | AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y})); |
75 | AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3})); |
76 | Node* mul = res->node(); |
77 | AT_ASSERT(mul->kind().toQualString() == std::string("foo::mul" )); |
78 | AT_ASSERT(mul->inputs() == std::vector<Value*>({t0, t2})); |
79 | AT_ASSERT(mul->outputs() == std::vector<Value*>({res, t3})); |
80 | Node* add = t2->node(); |
81 | AT_ASSERT(add->kind().toQualString() == std::string("foo::add" )); |
82 | AT_ASSERT(add->inputs() == std::vector<Value*>({t0, t1})); |
83 | AT_ASSERT(add->outputs() == std::vector<Value*>({t2})); |
84 | } |
85 | |
86 | TEST(IRParserTest, NestedBlock) { |
87 | checkRoundtrip(R"IR( |
88 | graph(): |
89 | %0 : Tensor = a::a() |
90 | block0(): |
91 | %1 : Tensor = b::b() |
92 | block0(): |
93 | %2 : Tensor = c::c() |
94 | -> () |
95 | -> () |
96 | %3 : Tensor = d::d() |
97 | return (%3) |
98 | )IR" ); |
99 | } |
100 | |
101 | TEST(IRParserTest, If) { |
102 | checkRoundtrip(R"IR( |
103 | graph(%0 : Tensor, |
104 | %1 : Tensor, |
105 | %2 : Tensor): |
106 | %3 : int = prim::Constant[value=1]() |
107 | %4 : Tensor = aten::add(%0, %1, %3) |
108 | %5 : Tensor = prim::If(%2) |
109 | block0(): |
110 | %6 : int = prim::Constant[value=1]() |
111 | %7 : Tensor = aten::add(%1, %3, %6) |
112 | %8 : int = prim::Constant[value=1]() |
113 | %9 : Tensor = aten::add(%7, %3, %8) |
114 | -> (%9) |
115 | %10 : int = prim::Constant[value=1]() |
116 | %11 : Tensor = aten::add(%5, %3, %10) |
117 | return (%11) |
118 | )IR" ); |
119 | } |
120 | |
121 | TEST(IRParserTest, If2) { |
122 | checkRoundtrip(R"IR( |
123 | graph(%0 : Tensor, |
124 | %1 : Tensor, |
125 | %2 : Tensor): |
126 | %3 : int = prim::Constant[value=-1]() |
127 | %4 : Tensor = aten::add(%0, %1, %3) |
128 | %5 : Tensor = prim::If(%2) |
129 | block0(): |
130 | %6 : int = prim::Constant[value=1]() |
131 | %7 : Tensor = aten::add(%1, %3, %6) |
132 | %8 : int = prim::Constant[value=1]() |
133 | %9 : Tensor = aten::add(%7, %3, %8) |
134 | -> (%9) |
135 | %10 : int = prim::Constant[value=-987]() |
136 | %11 : Tensor = aten::add(%5, %3, %10) |
137 | return (%11) |
138 | )IR" ); |
139 | } |
140 | |
141 | TEST(IRParserTest, InferredTypeIsTensor) { |
142 | auto graph = std::make_shared<Graph>(); |
143 | parseIR( |
144 | R"IR( |
145 | graph(%a): |
146 | return (%a))IR" , |
147 | &*graph); |
148 | AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get())); |
149 | } |
150 | |
151 | TEST(IRParserTest, ValueReuse) { |
152 | // Check that parser correctly handles values reusing the same name. |
153 | auto graph = std::make_shared<Graph>(); |
154 | parseIR( |
155 | R"IR( |
156 | graph(%x): |
157 | %x = a::a(%x) |
158 | %x = b::b(%x) |
159 | return (%x))IR" , |
160 | &*graph); |
161 | Value* x0 = graph->inputs()[0]; |
162 | Value* x2 = graph->outputs()[0]; |
163 | Node* b = x2->node(); |
164 | Value* x1 = b->inputs()[0]; |
165 | Node* a = x1->node(); |
166 | AT_ASSERT(a->inputs() == std::vector<Value*>({x0})); |
167 | AT_ASSERT(a->outputs() == std::vector<Value*>({x1})); |
168 | AT_ASSERT(b->inputs() == std::vector<Value*>({x1})); |
169 | AT_ASSERT(b->outputs() == std::vector<Value*>({x2})); |
170 | } |
171 | |
172 | TEST(IRParserTest, Attributes) { |
173 | // Check that parser handles attributes and types. |
174 | checkRoundtrip( |
175 | R"IR( |
176 | graph(%0 : Tensor, |
177 | %1 : Tensor, |
178 | %2 : Tensor): |
179 | %3 : int, %4 : Tensor = qqq::qqq[i_asdf=2, f_asdf=3., s_asdf="hello", ss_asdf=["hello world", "bye bye"]](%0) |
180 | %5 : int, %6 : Tensor = ppp::ppp[i_asdf=2, f_asdf=3., s_asdf="\"\"\"\"\nhe\"llo", q=[3, 2, 4]](%0) |
181 | %7 : float = vvv::vvv[s_asdf="hello"](%0) |
182 | %8 : string = z::z() |
183 | return (%7) |
184 | )IR" ); |
185 | } |
186 | |
187 | TEST(IRParserTest, OptionalTypes) { |
188 | checkRoundtrip( |
189 | R"IR( |
190 | graph(%0 : Tensor, |
191 | %1 : Tensor, |
192 | %2 : Tensor): |
193 | %3 : int? = prim::Constant() |
194 | return (%3) |
195 | )IR" ); |
196 | } |
197 | |
198 | TEST(IRParserTest, StarTensor) { |
199 | checkRoundtrip( |
200 | R"IR( |
201 | graph(%0 : Tensor, |
202 | %1 : Tensor, |
203 | %2 : Tensor): |
204 | %3 : Float(*, *, *) = prim::Constant() |
205 | return (%3) |
206 | )IR" ); |
207 | } |
208 | |
209 | TEST(IRParserTest, UnshapedTensor) { |
210 | checkRoundtrip( |
211 | R"IR( |
212 | graph(%0 : Tensor, |
213 | %1 : Tensor, |
214 | %2 : Tensor): |
215 | %3 : Long() = prim::Constant() |
216 | return (%3) |
217 | )IR" ); |
218 | } |
219 | |
220 | TEST(IRParserTest, ShapedTensor) { |
221 | checkRoundtrip( |
222 | R"IR( |
223 | graph(%0 : Tensor, |
224 | %1 : Tensor, |
225 | %2 : Tensor): |
226 | %3 : Double(4, 4, 5) = prim::Constant() |
227 | return (%3) |
228 | )IR" ); |
229 | } |
230 | |
231 | TEST(IRParserTest, NestedContrainer) { |
232 | checkRoundtrip( |
233 | R"IR( |
234 | graph(): |
235 | %0 : float[] = prim::Constant[value=[1., 2., 3.]]() |
236 | %1 : str[] = prim::Constant[value=["ab", "cd", "ef"]]() |
237 | %2 : (float[], str[]) = prim::TupleConstruct(%0, %1) |
238 | return (%2) |
239 | )IR" ); |
240 | } |
241 | |
242 | TEST(IRParserTest, MalformedShapeAnnotation) { |
243 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
244 | EXPECT_ANY_THROW(checkRoundtrip( |
245 | R"IR( |
246 | graph(%0 : Tensor, |
247 | %1 : Tensor, |
248 | %2 : Tensor): |
249 | %3 : Double(4!, 4, 5) = prim::Constant() |
250 | return (%3) |
251 | )IR" )); |
252 | } |
253 | |
254 | TEST(IRParserTest, FileCheck) { |
255 | auto graph = std::make_shared<Graph>(); |
256 | const std::string& text = |
257 | R"IR( |
258 | graph(%a): |
259 | # CHECK: return |
260 | return (%a))IR" ; |
261 | |
262 | parseIR(text, &*graph); |
263 | AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get())); |
264 | torch::jit::testing::FileCheck().run(text, *graph); |
265 | } |
266 | |
267 | TEST(IRParserTest, Strides) { |
268 | auto graph = std::make_shared<Graph>(); |
269 | std::unordered_map<std::string, Value*> vmap; |
270 | parseIR( |
271 | R"IR( |
272 | graph(%a : Float(4, 5), |
273 | %b : Float(4, 5, strides=[5, 1]), |
274 | %c : Double(*, *)): |
275 | return (%a) |
276 | )IR" , |
277 | &*graph, |
278 | vmap); |
279 | Value* a = graph->inputs()[0]; |
280 | Value* b = graph->inputs()[1]; |
281 | Value* c = graph->inputs()[2]; |
282 | |
283 | auto a_type = a->type()->cast<TensorType>(); |
284 | auto a_sizes = *a_type->sizes().concrete_sizes(); |
285 | auto a_strides = a_type->strides().concrete_sizes(); |
286 | AT_ASSERT(a_sizes[0] == 4 && a_sizes[1] == 5); |
287 | AT_ASSERT(a_strides == c10::nullopt); |
288 | |
289 | auto b_type = b->type()->cast<TensorType>(); |
290 | auto b_sizes = *b_type->sizes().concrete_sizes(); |
291 | auto b_strides = *(b_type->strides().sizes()); |
292 | AT_ASSERT(b_sizes[0] == 4 && b_sizes[1] == 5); |
293 | AT_ASSERT(*b_strides[0] == 5 && *b_strides[1] == 1); |
294 | |
295 | auto c_type = c->type()->cast<TensorType>(); |
296 | AT_ASSERT(*c_type->sizes().size() == 2); |
297 | AT_ASSERT(c_type->sizes().concrete_sizes() == c10::nullopt); |
298 | AT_ASSERT(c_type->strides().concrete_sizes() == c10::nullopt); |
299 | } |
300 | |
301 | TEST(IRParserTest, MalformedStrides) { |
302 | auto graph = std::make_shared<Graph>(); |
303 | std::unordered_map<std::string, Value*> vmap; |
304 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
305 | EXPECT_ANY_THROW(parseIR( |
306 | R"IR( |
307 | graph(%a : Float(4, strides=[5], 5)): |
308 | return (%a) |
309 | )IR" , |
310 | &*graph, |
311 | vmap)); |
312 | } |
313 | |
314 | TEST(IRParserTest, TensorShapes) { |
315 | checkRoundtrip( |
316 | R"IR( |
317 | graph(%a : Float(4, 5), |
318 | %b : Float(4, 5, strides=[5, 1]), |
319 | %c : Double(*, *)): |
320 | return (%a) |
321 | )IR" ); |
322 | } |
323 | |
324 | TEST(IRParserTest, DeviceAndRequiresGradTensors) { |
325 | checkRoundtrip( |
326 | R"IR( |
327 | graph(%a : Float(*, *, device=cpu), |
328 | %b : Float(*, *, requires_grad=1), |
329 | %c : Long(5, 10, requires_grad=1, device=cpu), |
330 | %d : Float(5, requires_grad=0, device=cuda:2), |
331 | %e : Long(4, 3, 1, strides=[6, 2, 1], requires_grad=0, device=cuda:1), |
332 | %f : Float(), |
333 | %g : Float(device=cpu), |
334 | %h : Float(requires_grad=1), |
335 | %i : Float(requires_grad=0, device=cuda:1), |
336 | %j : Double(*, *, requires_grad=0)): |
337 | return (%a) |
338 | )IR" ); |
339 | } |
340 | |
341 | TEST(IRParserTest, ListConstant) { |
342 | auto graph = std::make_shared<Graph>(); |
343 | parseIR( |
344 | R"IR( |
345 | graph(): |
346 | %d : int[] = prim::Constant[value=[1,2,3]]() |
347 | return (%d) |
348 | )IR" , |
349 | &*graph); |
350 | Node* n = graph->outputs()[0]->node(); |
351 | AT_ASSERT(n->kind() == prim::Constant); |
352 | AT_ASSERT(n->kindOf(attr::value) == AttributeKind::ival); |
353 | const auto& genericList = n->ival(attr::value).toList(); |
354 | std::vector<int> int_vals; |
355 | // NOLINTNEXTLINE(performance-implicit-conversion-in-loop) |
356 | for (const IValue& ival : genericList) { |
357 | int_vals.push_back(ival.toInt()); |
358 | } |
359 | AT_ASSERT(int_vals.size() == 3); |
360 | AT_ASSERT(int_vals[0] == 1 && int_vals[1] == 2 && int_vals[2] == 3); |
361 | } |
362 | |
363 | TEST(IRParserTest, PartialStarTensor) { |
364 | checkRoundtrip( |
365 | R"IR( |
366 | graph(%x : Float(10, *, 10)): |
367 | return (%x) |
368 | )IR" ); |
369 | } |
370 | |
371 | TEST(IRParserTest, ComplexTensorAttributes) { |
372 | checkRoundtrip( |
373 | R"IR( |
374 | graph(%x : Double(*, 200, *, requires_grad=1, device=cuda:1), |
375 | %b : Float(5, *, requires_grad=1), |
376 | %c : Long(*, 10, device=cpu)): |
377 | return (%x) |
378 | )IR" ); |
379 | } |
380 | } // namespace jit |
381 | } // namespace torch |
382 | |