1 | #include <gtest/gtest.h> |
2 | |
3 | #include "test/cpp/jit/test_utils.h" |
4 | |
5 | #include <torch/csrc/jit/testing/file_check.h> |
6 | #include "torch/csrc/jit/passes/common_subexpression_elimination.h" |
7 | #include "torch/csrc/jit/passes/utils/subgraph_utils.h" |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | TEST(SubgraphUtilsTest, Basic) { |
13 | auto graph = build_lstm(); |
14 | EliminateCommonSubexpression(graph); |
15 | |
16 | std::vector<Node*> originalNodes( |
17 | graph->nodes().begin(), graph->nodes().end()); |
18 | |
19 | for (bool reverse_iterate : {true, false}) { |
20 | // Merge everything into a single subgraph |
21 | bool first = true; |
22 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
23 | Node* subgraph; |
24 | auto it = |
25 | reverse_iterate ? graph->nodes().rbegin() : graph->nodes().begin(); |
26 | auto end = reverse_iterate ? graph->nodes().rend() : graph->nodes().end(); |
27 | for (; it != end;) { |
28 | if (first) { |
29 | subgraph = SubgraphUtils::createSingletonSubgraph( |
30 | *it, prim::DifferentiableGraph); |
31 | it = reverse_iterate ? ++subgraph->reverseIterator() |
32 | : ++subgraph->iterator(); |
33 | first = false; |
34 | } |
35 | |
36 | SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph); |
37 | it = reverse_iterate ? ++subgraph->reverseIterator() |
38 | : ++subgraph->iterator(); |
39 | } |
40 | |
41 | // Unmerge and compare with original node listing |
42 | // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) |
43 | SubgraphUtils::unmergeSubgraph(subgraph); |
44 | EliminateCommonSubexpression(graph); |
45 | |
46 | std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end()); |
47 | ASSERT_EQ(originalNodes.size(), newNodes.size()); |
48 | } |
49 | } |
50 | |
51 | TEST(SubgraphUtilsTest, MergeSubgraphs) { |
52 | auto graph = std::make_shared<Graph>(); |
53 | std::unordered_map<std::string, Value*> parse_map; |
54 | parseIR( |
55 | R"IR( |
56 | graph(%a : Tensor, %b : Tensor, %c : Tensor): |
57 | %x : Tensor = aten::sigmoid(%a) |
58 | %y : Tensor = aten::mul(%a, %b) |
59 | %p : Tensor = aten::div(%c, %b) |
60 | %q1 : Tensor = aten::mul(%p, %a) |
61 | %q2 : Tensor = aten::tanh(%q1) |
62 | %q3 : Tensor = aten::tanh(%q2) |
63 | %q4 : Tensor = aten::tanh(%q3) |
64 | %q5 : Tensor = aten::hardsigmoid(%q4) |
65 | return (%x, %y, %q5))IR" , |
66 | &*graph, |
67 | parse_map); |
68 | |
69 | std::vector<Node*> originalNodes( |
70 | graph->nodes().begin(), graph->nodes().end()); |
71 | for (bool reverse_merge : {true, false}) { |
72 | // Merge everything into two adjacent subgraphs |
73 | Node* graph1 = SubgraphUtils::createSingletonSubgraph( |
74 | *graph->nodes().begin(), prim::DifferentiableGraph); |
75 | while (true) { |
76 | Node* next = graph1->next(); |
77 | if (next->kind() == aten::tanh) { |
78 | break; |
79 | } |
80 | SubgraphUtils::mergeNodeIntoSubgraph(next, graph1); |
81 | } |
82 | Node* graph2 = SubgraphUtils::createSingletonSubgraph( |
83 | graph1->next(), prim::DifferentiableGraph); |
84 | while (graph2->next() != *graph->nodes().end()) { |
85 | SubgraphUtils::mergeNodeIntoSubgraph(graph2->next(), graph2); |
86 | } |
87 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
88 | Node* subgraph; |
89 | if (reverse_merge) { |
90 | SubgraphUtils::mergeNodeIntoSubgraph(graph2, graph1); |
91 | subgraph = graph1; |
92 | } else { |
93 | SubgraphUtils::mergeNodeIntoSubgraph(graph1, graph2); |
94 | subgraph = graph2; |
95 | } |
96 | auto run_file_check = [](std::shared_ptr<Graph> graph) { |
97 | graph->lint(); |
98 | testing::FileCheck() |
99 | .check("aten::sigmoid" ) |
100 | ->check("aten::mul" ) |
101 | ->check("aten::div" ) |
102 | ->check("aten::mul" ) |
103 | ->check_count("aten::tanh" , 3) |
104 | ->check("aten::hardsigmoid" ) |
105 | ->run(*graph); |
106 | }; |
107 | run_file_check(subgraph->g(attr::Subgraph)); |
108 | |
109 | // Unmerge and compare with original node listing |
110 | SubgraphUtils::unmergeSubgraph(subgraph); |
111 | EliminateCommonSubexpression(graph); |
112 | run_file_check(graph); |
113 | |
114 | std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end()); |
115 | ASSERT_EQ(originalNodes.size(), newNodes.size()); |
116 | } |
117 | } |
118 | |
119 | TEST(SubgraphUtilsTest, GraphName) { |
120 | auto graph = std::make_shared<Graph>(); |
121 | |
122 | std::unordered_map<std::string, Value*> parse_map; |
123 | parseIR( |
124 | R"IR( |
125 | graph(%a : Tensor, %b : Tensor, %c : Tensor): |
126 | %x : Tensor = aten::tanh(%a) |
127 | %y : Tensor = aten::mul(%a, %b) |
128 | %p : Tensor = aten::div(%c, %b) |
129 | %q1 : Tensor = aten::mul(%p, %a) |
130 | %q2 : Tensor = aten::tanh(%q1) |
131 | %q3 : Tensor = aten::tanh(%q2) |
132 | %q4 : Tensor = aten::tanh(%q3) |
133 | %q5 : Tensor = aten::tanh(%q4) |
134 | return (%x, %y, %q5))IR" , |
135 | &*graph, |
136 | parse_map); |
137 | std::string ref_full_name = "graph_tanh_mul_div_mul_tanh_tanh_tanh_tanh" ; |
138 | std::string full_name = |
139 | SubgraphUtils::generateNameForGraph(graph, 80, "graph" ); |
140 | ASSERT_EQ(full_name, ref_full_name); |
141 | |
142 | std::string truncated_name = |
143 | SubgraphUtils::generateNameForGraph(graph, 10, "graph" ); |
144 | |
145 | ASSERT_LE(truncated_name.size(), ref_full_name.size()); |
146 | } |
147 | |
148 | } // namespace jit |
149 | } // namespace torch |
150 | |