1#include <iostream>
2#include <sstream>
3#include <string>
4
5#include <gtest/gtest.h>
6
7#include <test/cpp/jit/test_utils.h>
8#include <torch/csrc/jit/ir/irparser.h>
9#include <torch/csrc/jit/runtime/graph_iterator.h>
10#include <torch/jit.h>
11#include <torch/script.h>
12#include <torch/torch.h>
13
14namespace torch {
15namespace jit {
16
/**
 * Returns a new map with the keys and values of `map` swapped.
 *
 * If several keys share the same value, only one of them survives: whichever
 * entry is reached first while iterating `map` (emplace ignores later entries
 * with an already-present key). Because unordered_map iteration order is
 * unspecified, which of the colliding keys wins is unspecified as well.
 *
 * @param map The map to invert; it is not modified.
 * @return The inverted map.
 */
template <typename K, typename V>
std::unordered_map<V, K> invert_map(const std::unordered_map<K, V>& map) {
  std::unordered_map<V, K> inverted;
  inverted.reserve(map.size());
  // `const auto&` binds directly to the stored pair<const K, V>; spelling the
  // element type as pair<K, V> would convert (i.e. copy) every entry.
  for (const auto& entry : map) {
    inverted.emplace(entry.second, entry.first);
  }
  return inverted;
}
28
29/**
30 * Traverses the graph using the DepthFirstGraphNodeIterator and
31 * returns an array containing the original names in the string
32 * graph.
33 */
34std::vector<std::string> traverse_depth_first(
35 std::string graph_string,
36 int max_count = 100) {
37 auto graph = std::make_shared<Graph>();
38 std::unordered_map<std::string, Value*> vmap;
39 torch::jit::parseIR(graph_string, graph.get(), vmap);
40 auto get_name = invert_map(vmap);
41
42 std::vector<std::string> result;
43 DepthFirstGraphNodeIterator graph_it(graph);
44 Node* node = graph_it.next();
45 int count = 0;
46 while (node && count < max_count) {
47 std::stringstream buffer;
48 std::vector<const torch::jit::Node*> vec;
49 node->print(buffer, 0, &vec, false, true, true, false);
50 result.push_back(buffer.str());
51 node = graph_it.next();
52 ++count;
53 }
54 return result;
55}
56
57/** Checks that the iteration order matches the expected/provided order. */
58void assert_ordering(
59 std::vector<std::string> actual,
60 std::initializer_list<std::string> expected_list) {
61 auto expected = std::vector<std::string>(expected_list);
62 ASSERT_EQ(expected.size(), actual.size())
63 << "Got " << actual.size() << " elements (" << actual << ")"
64 << " expected " << expected.size() << " elements (" << expected << ")";
65 for (unsigned i = 0; i < expected.size(); i++) {
66 ASSERT_EQ(expected[i], actual[i])
67 << "Difference at index " << i << " in " << actual << " (expected "
68 << actual << ")";
69 }
70}
71
72TEST(GraphIteratorTest, ConstantReturnGraph) {
73 const auto graph_string = R"IR(
74 graph():
75 %1 : int = prim::Constant[value=0]()
76 return (%1))IR";
77 auto graph = std::make_shared<Graph>();
78 torch::jit::parseIR(graph_string, graph.get());
79 DepthFirstGraphNodeIterator graph_it(graph);
80 ASSERT_EQ(graph_it.next()->kind(), prim::Constant);
81 ASSERT_EQ(graph_it.next(), nullptr);
82}
83
84TEST(GraphIteratorTest, GraphWithParameters) {
85 const auto graph_string = R"IR(
86 graph(%0 : Double(2)):
87 %1 : int = prim::Constant[value=0]()
88 return (%0))IR";
89 auto ordering = traverse_depth_first(graph_string);
90 assert_ordering(ordering, {"%1 : int = prim::Constant[value=0]()"});
91}
92
93TEST(GraphIteratorTest, GraphWithIf) {
94 const auto graph_string = R"IR(
95graph(%a : Tensor):
96 %a : int = prim::Constant[value=30]()
97 %b : int = prim::Constant[value=10]()
98 %c : bool = aten::Bool(%a)
99 %d : int = prim::If(%c)
100 block0():
101 -> (%a)
102 block1():
103 -> (%b)
104 %e : int = prim::Constant[value=20]()
105 return (%d)
106)IR";
107 auto ordering = traverse_depth_first(graph_string);
108 assert_ordering(
109 ordering,
110 {"%1 : int = prim::Constant[value=30]()",
111 "%2 : int = prim::Constant[value=10]()",
112 "%3 : bool = aten::Bool(%1)",
113 "%4 : int = prim::If(%3)",
114 "%5 : int = prim::Constant[value=20]()"});
115}
116
117TEST(GraphIteratorTest, GraphWithNestedIf) {
118 const auto graph_string = R"IR(
119graph(%a.1 : Tensor,
120 %b.1 : Tensor):
121 %2 : int = prim::Constant[value=10]()
122 %3 : int = prim::Constant[value=20]()
123 %4 : int = prim::Constant[value=30]()
124 %5 : int = prim::Constant[value=40]()
125 %6 : bool = aten::Bool(%a.1)
126 %7 : int = prim::If(%6)
127 block0():
128 %8 : bool = aten::Bool(%b.1)
129 %9 : int = prim::If(%8)
130 block0():
131 -> (%2)
132 block1():
133 -> (%3)
134 -> (%9)
135 block1():
136 %10 : bool = aten::Bool(%b.1)
137 %11 : int = prim::If(%10)
138 block0():
139 -> (%4)
140 block1():
141 -> (%5)
142 -> (%11)
143 %8 : bool = aten::Bool(%b.1)
144 %9 : int = prim::If(%8)
145 block0():
146 -> (%2)
147 block1():
148 -> (%3)
149 %10 : bool = aten::Bool(%b.1)
150 %11 : int = prim::If(%10)
151 block0():
152 -> (%4)
153 block1():
154 -> (%5)
155 return (%7)
156)IR";
157 auto ordering = traverse_depth_first(graph_string);
158 assert_ordering(
159 ordering,
160 {"%2 : int = prim::Constant[value=10]()",
161 "%3 : int = prim::Constant[value=20]()",
162 "%4 : int = prim::Constant[value=30]()",
163 "%5 : int = prim::Constant[value=40]()",
164 "%6 : bool = aten::Bool(%a.1)",
165 "%7 : int = prim::If(%6)",
166 "%8 : bool = aten::Bool(%b.1)",
167 "%9 : int = prim::If(%8)",
168 "%10 : bool = aten::Bool(%b.1)",
169 "%11 : int = prim::If(%10)",
170 "%12 : bool = aten::Bool(%b.1)",
171 "%13 : int = prim::If(%12)",
172 "%14 : bool = aten::Bool(%b.1)",
173 "%15 : int = prim::If(%14)"});
174}
175
176TEST(GraphIteratorTest, GraphWithLoop) {
177 const auto graph_string = R"IR(
178graph(%a.1 : Tensor):
179 %1 : bool = prim::Constant[value=1]()
180 %2 : int = prim::Constant[value=10]()
181 %3 : int = prim::Constant[value=1]()
182 %4 : Tensor = prim::Loop(%2, %1, %a.1)
183 block0(%i : int, %b.9 : Tensor):
184 %5 : Tensor = aten::add_(%b.9, %3, %3)
185 -> (%1, %5)
186 %6 : Tensor = prim::Loop(%2, %1, %a.1)
187 block0(%i : int, %b.9 : Tensor):
188 -> (%1, %4)
189 return (%6)
190)IR";
191 auto ordering = traverse_depth_first(graph_string);
192 assert_ordering(
193 ordering,
194 {"%1 : bool = prim::Constant[value=1]()",
195 "%2 : int = prim::Constant[value=10]()",
196 "%3 : int = prim::Constant[value=1]()",
197 "%4 : Tensor = prim::Loop(%2, %1, %a.1)",
198 "%7 : Tensor = aten::add_(%b.10, %3, %3)",
199 "%8 : Tensor = prim::Loop(%2, %1, %a.1)"});
200}
201
202} // namespace jit
203} // namespace torch
204