#include <gtest/gtest.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <cstdint>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
9
10namespace torch {
11namespace jit {
12
13/** \brief Parse IR from \p S, print the parsed graph and verify that the output
14 * string matches the original string.
15 *
16 * The function is sensitive to value naming and whitespace, so it should be
17 * used with care. Nevertheless, it helps to keep tests more compact.
18 */
19static void checkRoundtrip(const std::string& s) {
20 auto graph = std::make_shared<Graph>();
21 parseIR(s, &*graph);
22 std::ostringstream ss;
23 ss << *graph;
24 std::string parsed = ss.str();
25
26 // Skip whitespace in the beginning of the input string.
27 int i = 0;
28 for (char c : s) {
29 if (!isspace(c)) {
30 break;
31 }
32 i++;
33 }
34 std::string original = s.substr(i, s.size());
35 if (original != parsed) {
36 std::cerr << "Input:" << std::endl << original << std::endl;
37 std::cerr << "Parsed:" << std::endl << parsed << std::endl;
38 }
39 AT_ASSERT(original == parsed);
40}
41
42TEST(IRParserTest, Basic) {
43 auto graph = std::make_shared<Graph>();
44 std::unordered_map<std::string, Value*> vmap;
45 parseIR(
46 R"IR(
47graph(%0 : Tensor, %1 : Tensor):
48 %2 : Tensor = foo::add(%0, %1)
49 %res, %3 = foo::mul(%0, %2)
50 %x, %y = foo::combine(%res, %2, %3)
51 return (%x, %y, %res))IR",
52 &*graph,
53 vmap);
54
55 AT_ASSERT(graph->inputs().size() == 2);
56 AT_ASSERT(graph->outputs().size() == 3);
57 Value* x = graph->outputs()[0];
58 Value* y = graph->outputs()[1];
59 Value* res = graph->outputs()[2];
60 Value* t0 = graph->inputs()[0];
61 Value* t1 = graph->inputs()[1];
62 AT_ASSERT(vmap["x"] == x);
63 AT_ASSERT(vmap["y"] == y);
64 AT_ASSERT(vmap["res"] == res);
65 AT_ASSERT(vmap["0"] == t0);
66 AT_ASSERT(vmap["1"] == t1);
67 AT_ASSERT(x->node() == y->node());
68 Node* comb = x->node();
69 Value* t2 = comb->inputs()[1];
70 Value* t3 = comb->inputs()[2];
71 AT_ASSERT(vmap["2"] == t2);
72 AT_ASSERT(vmap["3"] == t3);
73 AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine"));
74 AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y}));
75 AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3}));
76 Node* mul = res->node();
77 AT_ASSERT(mul->kind().toQualString() == std::string("foo::mul"));
78 AT_ASSERT(mul->inputs() == std::vector<Value*>({t0, t2}));
79 AT_ASSERT(mul->outputs() == std::vector<Value*>({res, t3}));
80 Node* add = t2->node();
81 AT_ASSERT(add->kind().toQualString() == std::string("foo::add"));
82 AT_ASSERT(add->inputs() == std::vector<Value*>({t0, t1}));
83 AT_ASSERT(add->outputs() == std::vector<Value*>({t2}));
84}
85
86TEST(IRParserTest, NestedBlock) {
87 checkRoundtrip(R"IR(
88graph():
89 %0 : Tensor = a::a()
90 block0():
91 %1 : Tensor = b::b()
92 block0():
93 %2 : Tensor = c::c()
94 -> ()
95 -> ()
96 %3 : Tensor = d::d()
97 return (%3)
98)IR");
99}
100
101TEST(IRParserTest, If) {
102 checkRoundtrip(R"IR(
103graph(%0 : Tensor,
104 %1 : Tensor,
105 %2 : Tensor):
106 %3 : int = prim::Constant[value=1]()
107 %4 : Tensor = aten::add(%0, %1, %3)
108 %5 : Tensor = prim::If(%2)
109 block0():
110 %6 : int = prim::Constant[value=1]()
111 %7 : Tensor = aten::add(%1, %3, %6)
112 %8 : int = prim::Constant[value=1]()
113 %9 : Tensor = aten::add(%7, %3, %8)
114 -> (%9)
115 %10 : int = prim::Constant[value=1]()
116 %11 : Tensor = aten::add(%5, %3, %10)
117 return (%11)
118)IR");
119}
120
121TEST(IRParserTest, If2) {
122 checkRoundtrip(R"IR(
123graph(%0 : Tensor,
124 %1 : Tensor,
125 %2 : Tensor):
126 %3 : int = prim::Constant[value=-1]()
127 %4 : Tensor = aten::add(%0, %1, %3)
128 %5 : Tensor = prim::If(%2)
129 block0():
130 %6 : int = prim::Constant[value=1]()
131 %7 : Tensor = aten::add(%1, %3, %6)
132 %8 : int = prim::Constant[value=1]()
133 %9 : Tensor = aten::add(%7, %3, %8)
134 -> (%9)
135 %10 : int = prim::Constant[value=-987]()
136 %11 : Tensor = aten::add(%5, %3, %10)
137 return (%11)
138)IR");
139}
140
141TEST(IRParserTest, InferredTypeIsTensor) {
142 auto graph = std::make_shared<Graph>();
143 parseIR(
144 R"IR(
145graph(%a):
146 return (%a))IR",
147 &*graph);
148 AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get()));
149}
150
151TEST(IRParserTest, ValueReuse) {
152 // Check that parser correctly handles values reusing the same name.
153 auto graph = std::make_shared<Graph>();
154 parseIR(
155 R"IR(
156graph(%x):
157 %x = a::a(%x)
158 %x = b::b(%x)
159 return (%x))IR",
160 &*graph);
161 Value* x0 = graph->inputs()[0];
162 Value* x2 = graph->outputs()[0];
163 Node* b = x2->node();
164 Value* x1 = b->inputs()[0];
165 Node* a = x1->node();
166 AT_ASSERT(a->inputs() == std::vector<Value*>({x0}));
167 AT_ASSERT(a->outputs() == std::vector<Value*>({x1}));
168 AT_ASSERT(b->inputs() == std::vector<Value*>({x1}));
169 AT_ASSERT(b->outputs() == std::vector<Value*>({x2}));
170}
171
172TEST(IRParserTest, Attributes) {
173 // Check that parser handles attributes and types.
174 checkRoundtrip(
175 R"IR(
176graph(%0 : Tensor,
177 %1 : Tensor,
178 %2 : Tensor):
179 %3 : int, %4 : Tensor = qqq::qqq[i_asdf=2, f_asdf=3., s_asdf="hello", ss_asdf=["hello world", "bye bye"]](%0)
180 %5 : int, %6 : Tensor = ppp::ppp[i_asdf=2, f_asdf=3., s_asdf="\"\"\"\"\nhe\"llo", q=[3, 2, 4]](%0)
181 %7 : float = vvv::vvv[s_asdf="hello"](%0)
182 %8 : string = z::z()
183 return (%7)
184)IR");
185}
186
187TEST(IRParserTest, OptionalTypes) {
188 checkRoundtrip(
189 R"IR(
190graph(%0 : Tensor,
191 %1 : Tensor,
192 %2 : Tensor):
193 %3 : int? = prim::Constant()
194 return (%3)
195)IR");
196}
197
198TEST(IRParserTest, StarTensor) {
199 checkRoundtrip(
200 R"IR(
201graph(%0 : Tensor,
202 %1 : Tensor,
203 %2 : Tensor):
204 %3 : Float(*, *, *) = prim::Constant()
205 return (%3)
206)IR");
207}
208
209TEST(IRParserTest, UnshapedTensor) {
210 checkRoundtrip(
211 R"IR(
212graph(%0 : Tensor,
213 %1 : Tensor,
214 %2 : Tensor):
215 %3 : Long() = prim::Constant()
216 return (%3)
217)IR");
218}
219
220TEST(IRParserTest, ShapedTensor) {
221 checkRoundtrip(
222 R"IR(
223graph(%0 : Tensor,
224 %1 : Tensor,
225 %2 : Tensor):
226 %3 : Double(4, 4, 5) = prim::Constant()
227 return (%3)
228)IR");
229}
230
231TEST(IRParserTest, NestedContrainer) {
232 checkRoundtrip(
233 R"IR(
234graph():
235 %0 : float[] = prim::Constant[value=[1., 2., 3.]]()
236 %1 : str[] = prim::Constant[value=["ab", "cd", "ef"]]()
237 %2 : (float[], str[]) = prim::TupleConstruct(%0, %1)
238 return (%2)
239)IR");
240}
241
242TEST(IRParserTest, MalformedShapeAnnotation) {
243 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
244 EXPECT_ANY_THROW(checkRoundtrip(
245 R"IR(
246graph(%0 : Tensor,
247 %1 : Tensor,
248 %2 : Tensor):
249 %3 : Double(4!, 4, 5) = prim::Constant()
250 return (%3)
251)IR"));
252}
253
254TEST(IRParserTest, FileCheck) {
255 auto graph = std::make_shared<Graph>();
256 const std::string& text =
257 R"IR(
258 graph(%a):
259 # CHECK: return
260 return (%a))IR";
261
262 parseIR(text, &*graph);
263 AT_ASSERT(graph->inputs()[0]->type()->isSubtypeOf(*TensorType::get()));
264 torch::jit::testing::FileCheck().run(text, *graph);
265}
266
267TEST(IRParserTest, Strides) {
268 auto graph = std::make_shared<Graph>();
269 std::unordered_map<std::string, Value*> vmap;
270 parseIR(
271 R"IR(
272graph(%a : Float(4, 5),
273 %b : Float(4, 5, strides=[5, 1]),
274 %c : Double(*, *)):
275 return (%a)
276)IR",
277 &*graph,
278 vmap);
279 Value* a = graph->inputs()[0];
280 Value* b = graph->inputs()[1];
281 Value* c = graph->inputs()[2];
282
283 auto a_type = a->type()->cast<TensorType>();
284 auto a_sizes = *a_type->sizes().concrete_sizes();
285 auto a_strides = a_type->strides().concrete_sizes();
286 AT_ASSERT(a_sizes[0] == 4 && a_sizes[1] == 5);
287 AT_ASSERT(a_strides == c10::nullopt);
288
289 auto b_type = b->type()->cast<TensorType>();
290 auto b_sizes = *b_type->sizes().concrete_sizes();
291 auto b_strides = *(b_type->strides().sizes());
292 AT_ASSERT(b_sizes[0] == 4 && b_sizes[1] == 5);
293 AT_ASSERT(*b_strides[0] == 5 && *b_strides[1] == 1);
294
295 auto c_type = c->type()->cast<TensorType>();
296 AT_ASSERT(*c_type->sizes().size() == 2);
297 AT_ASSERT(c_type->sizes().concrete_sizes() == c10::nullopt);
298 AT_ASSERT(c_type->strides().concrete_sizes() == c10::nullopt);
299}
300
301TEST(IRParserTest, MalformedStrides) {
302 auto graph = std::make_shared<Graph>();
303 std::unordered_map<std::string, Value*> vmap;
304 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
305 EXPECT_ANY_THROW(parseIR(
306 R"IR(
307graph(%a : Float(4, strides=[5], 5)):
308 return (%a)
309)IR",
310 &*graph,
311 vmap));
312}
313
314TEST(IRParserTest, TensorShapes) {
315 checkRoundtrip(
316 R"IR(
317graph(%a : Float(4, 5),
318 %b : Float(4, 5, strides=[5, 1]),
319 %c : Double(*, *)):
320 return (%a)
321)IR");
322}
323
324TEST(IRParserTest, DeviceAndRequiresGradTensors) {
325 checkRoundtrip(
326 R"IR(
327graph(%a : Float(*, *, device=cpu),
328 %b : Float(*, *, requires_grad=1),
329 %c : Long(5, 10, requires_grad=1, device=cpu),
330 %d : Float(5, requires_grad=0, device=cuda:2),
331 %e : Long(4, 3, 1, strides=[6, 2, 1], requires_grad=0, device=cuda:1),
332 %f : Float(),
333 %g : Float(device=cpu),
334 %h : Float(requires_grad=1),
335 %i : Float(requires_grad=0, device=cuda:1),
336 %j : Double(*, *, requires_grad=0)):
337 return (%a)
338)IR");
339}
340
341TEST(IRParserTest, ListConstant) {
342 auto graph = std::make_shared<Graph>();
343 parseIR(
344 R"IR(
345graph():
346 %d : int[] = prim::Constant[value=[1,2,3]]()
347 return (%d)
348)IR",
349 &*graph);
350 Node* n = graph->outputs()[0]->node();
351 AT_ASSERT(n->kind() == prim::Constant);
352 AT_ASSERT(n->kindOf(attr::value) == AttributeKind::ival);
353 const auto& genericList = n->ival(attr::value).toList();
354 std::vector<int> int_vals;
355 // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
356 for (const IValue& ival : genericList) {
357 int_vals.push_back(ival.toInt());
358 }
359 AT_ASSERT(int_vals.size() == 3);
360 AT_ASSERT(int_vals[0] == 1 && int_vals[1] == 2 && int_vals[2] == 3);
361}
362
363TEST(IRParserTest, PartialStarTensor) {
364 checkRoundtrip(
365 R"IR(
366graph(%x : Float(10, *, 10)):
367 return (%x)
368)IR");
369}
370
371TEST(IRParserTest, ComplexTensorAttributes) {
372 checkRoundtrip(
373 R"IR(
374graph(%x : Double(*, 200, *, requires_grad=1, device=cuda:1),
375 %b : Float(5, *, requires_grad=1),
376 %c : Long(*, 10, device=cpu)):
377 return (%x)
378)IR");
379}
380} // namespace jit
381} // namespace torch
382