1 | #include <iostream> |
2 | #include <sstream> |
3 | #include <string> |
4 | |
5 | #include <gtest/gtest.h> |
6 | |
7 | #include <test/cpp/jit/test_utils.h> |
8 | #include <torch/csrc/jit/ir/irparser.h> |
9 | #include <torch/csrc/jit/runtime/graph_iterator.h> |
10 | #include <torch/jit.h> |
11 | #include <torch/script.h> |
12 | #include <torch/torch.h> |
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | |
/**
 * Builds the inverse of an unordered map: every (key, value) entry of `map`
 * becomes a (value, key) entry of the result.
 *
 * When several keys share the same value only one of them survives: the one
 * met first while iterating `map` (the iteration order of an unordered map is
 * unspecified, so which key wins is unspecified as well).
 *
 * @param map the map to invert; it is only read, never modified.
 * @return a new map from each value of `map` to a key that mapped to it.
 */
template <typename K, typename V>
std::unordered_map<V, K> invert_map(const std::unordered_map<K, V>& map) {
  std::unordered_map<V, K> inverted;
  inverted.reserve(map.size());
  // Bind to the map's real element type (pair<const K, V>) so that no
  // temporary pair has to be materialized for each entry.
  for (const auto& entry : map) {
    // emplace leaves an already-present value untouched, so the first key
    // seen for a given value is the one that is kept.
    inverted.emplace(entry.second, entry.first);
  }
  return inverted;
}
28 | |
29 | /** |
30 | * Traverses the graph using the DepthFirstGraphNodeIterator and |
31 | * returns an array containing the original names in the string |
32 | * graph. |
33 | */ |
34 | std::vector<std::string> traverse_depth_first( |
35 | std::string graph_string, |
36 | int max_count = 100) { |
37 | auto graph = std::make_shared<Graph>(); |
38 | std::unordered_map<std::string, Value*> vmap; |
39 | torch::jit::parseIR(graph_string, graph.get(), vmap); |
40 | auto get_name = invert_map(vmap); |
41 | |
42 | std::vector<std::string> result; |
43 | DepthFirstGraphNodeIterator graph_it(graph); |
44 | Node* node = graph_it.next(); |
45 | int count = 0; |
46 | while (node && count < max_count) { |
47 | std::stringstream buffer; |
48 | std::vector<const torch::jit::Node*> vec; |
49 | node->print(buffer, 0, &vec, false, true, true, false); |
50 | result.push_back(buffer.str()); |
51 | node = graph_it.next(); |
52 | ++count; |
53 | } |
54 | return result; |
55 | } |
56 | |
57 | /** Checks that the iteration order matches the expected/provided order. */ |
58 | void assert_ordering( |
59 | std::vector<std::string> actual, |
60 | std::initializer_list<std::string> expected_list) { |
61 | auto expected = std::vector<std::string>(expected_list); |
62 | ASSERT_EQ(expected.size(), actual.size()) |
63 | << "Got " << actual.size() << " elements (" << actual << ")" |
64 | << " expected " << expected.size() << " elements (" << expected << ")" ; |
65 | for (unsigned i = 0; i < expected.size(); i++) { |
66 | ASSERT_EQ(expected[i], actual[i]) |
67 | << "Difference at index " << i << " in " << actual << " (expected " |
68 | << actual << ")" ; |
69 | } |
70 | } |
71 | |
72 | TEST(GraphIteratorTest, ConstantReturnGraph) { |
73 | const auto graph_string = R"IR( |
74 | graph(): |
75 | %1 : int = prim::Constant[value=0]() |
76 | return (%1))IR" ; |
77 | auto graph = std::make_shared<Graph>(); |
78 | torch::jit::parseIR(graph_string, graph.get()); |
79 | DepthFirstGraphNodeIterator graph_it(graph); |
80 | ASSERT_EQ(graph_it.next()->kind(), prim::Constant); |
81 | ASSERT_EQ(graph_it.next(), nullptr); |
82 | } |
83 | |
84 | TEST(GraphIteratorTest, GraphWithParameters) { |
85 | const auto graph_string = R"IR( |
86 | graph(%0 : Double(2)): |
87 | %1 : int = prim::Constant[value=0]() |
88 | return (%0))IR" ; |
89 | auto ordering = traverse_depth_first(graph_string); |
90 | assert_ordering(ordering, {"%1 : int = prim::Constant[value=0]()" }); |
91 | } |
92 | |
93 | TEST(GraphIteratorTest, GraphWithIf) { |
94 | const auto graph_string = R"IR( |
95 | graph(%a : Tensor): |
96 | %a : int = prim::Constant[value=30]() |
97 | %b : int = prim::Constant[value=10]() |
98 | %c : bool = aten::Bool(%a) |
99 | %d : int = prim::If(%c) |
100 | block0(): |
101 | -> (%a) |
102 | block1(): |
103 | -> (%b) |
104 | %e : int = prim::Constant[value=20]() |
105 | return (%d) |
106 | )IR" ; |
107 | auto ordering = traverse_depth_first(graph_string); |
108 | assert_ordering( |
109 | ordering, |
110 | {"%1 : int = prim::Constant[value=30]()" , |
111 | "%2 : int = prim::Constant[value=10]()" , |
112 | "%3 : bool = aten::Bool(%1)" , |
113 | "%4 : int = prim::If(%3)" , |
114 | "%5 : int = prim::Constant[value=20]()" }); |
115 | } |
116 | |
117 | TEST(GraphIteratorTest, GraphWithNestedIf) { |
118 | const auto graph_string = R"IR( |
119 | graph(%a.1 : Tensor, |
120 | %b.1 : Tensor): |
121 | %2 : int = prim::Constant[value=10]() |
122 | %3 : int = prim::Constant[value=20]() |
123 | %4 : int = prim::Constant[value=30]() |
124 | %5 : int = prim::Constant[value=40]() |
125 | %6 : bool = aten::Bool(%a.1) |
126 | %7 : int = prim::If(%6) |
127 | block0(): |
128 | %8 : bool = aten::Bool(%b.1) |
129 | %9 : int = prim::If(%8) |
130 | block0(): |
131 | -> (%2) |
132 | block1(): |
133 | -> (%3) |
134 | -> (%9) |
135 | block1(): |
136 | %10 : bool = aten::Bool(%b.1) |
137 | %11 : int = prim::If(%10) |
138 | block0(): |
139 | -> (%4) |
140 | block1(): |
141 | -> (%5) |
142 | -> (%11) |
143 | %8 : bool = aten::Bool(%b.1) |
144 | %9 : int = prim::If(%8) |
145 | block0(): |
146 | -> (%2) |
147 | block1(): |
148 | -> (%3) |
149 | %10 : bool = aten::Bool(%b.1) |
150 | %11 : int = prim::If(%10) |
151 | block0(): |
152 | -> (%4) |
153 | block1(): |
154 | -> (%5) |
155 | return (%7) |
156 | )IR" ; |
157 | auto ordering = traverse_depth_first(graph_string); |
158 | assert_ordering( |
159 | ordering, |
160 | {"%2 : int = prim::Constant[value=10]()" , |
161 | "%3 : int = prim::Constant[value=20]()" , |
162 | "%4 : int = prim::Constant[value=30]()" , |
163 | "%5 : int = prim::Constant[value=40]()" , |
164 | "%6 : bool = aten::Bool(%a.1)" , |
165 | "%7 : int = prim::If(%6)" , |
166 | "%8 : bool = aten::Bool(%b.1)" , |
167 | "%9 : int = prim::If(%8)" , |
168 | "%10 : bool = aten::Bool(%b.1)" , |
169 | "%11 : int = prim::If(%10)" , |
170 | "%12 : bool = aten::Bool(%b.1)" , |
171 | "%13 : int = prim::If(%12)" , |
172 | "%14 : bool = aten::Bool(%b.1)" , |
173 | "%15 : int = prim::If(%14)" }); |
174 | } |
175 | |
176 | TEST(GraphIteratorTest, GraphWithLoop) { |
177 | const auto graph_string = R"IR( |
178 | graph(%a.1 : Tensor): |
179 | %1 : bool = prim::Constant[value=1]() |
180 | %2 : int = prim::Constant[value=10]() |
181 | %3 : int = prim::Constant[value=1]() |
182 | %4 : Tensor = prim::Loop(%2, %1, %a.1) |
183 | block0(%i : int, %b.9 : Tensor): |
184 | %5 : Tensor = aten::add_(%b.9, %3, %3) |
185 | -> (%1, %5) |
186 | %6 : Tensor = prim::Loop(%2, %1, %a.1) |
187 | block0(%i : int, %b.9 : Tensor): |
188 | -> (%1, %4) |
189 | return (%6) |
190 | )IR" ; |
191 | auto ordering = traverse_depth_first(graph_string); |
192 | assert_ordering( |
193 | ordering, |
194 | {"%1 : bool = prim::Constant[value=1]()" , |
195 | "%2 : int = prim::Constant[value=10]()" , |
196 | "%3 : int = prim::Constant[value=1]()" , |
197 | "%4 : Tensor = prim::Loop(%2, %1, %a.1)" , |
198 | "%7 : Tensor = aten::add_(%b.10, %3, %3)" , |
199 | "%8 : Tensor = prim::Loop(%2, %1, %a.1)" }); |
200 | } |
201 | |
202 | } // namespace jit |
203 | } // namespace torch |
204 | |