1#include <gtest/gtest.h>
2
3#include "test/cpp/jit/test_utils.h"
4
5#include <torch/csrc/jit/testing/file_check.h>
6#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
7#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
8
9namespace torch {
10namespace jit {
11
12TEST(SubgraphUtilsTest, Basic) {
13 auto graph = build_lstm();
14 EliminateCommonSubexpression(graph);
15
16 std::vector<Node*> originalNodes(
17 graph->nodes().begin(), graph->nodes().end());
18
19 for (bool reverse_iterate : {true, false}) {
20 // Merge everything into a single subgraph
21 bool first = true;
22 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
23 Node* subgraph;
24 auto it =
25 reverse_iterate ? graph->nodes().rbegin() : graph->nodes().begin();
26 auto end = reverse_iterate ? graph->nodes().rend() : graph->nodes().end();
27 for (; it != end;) {
28 if (first) {
29 subgraph = SubgraphUtils::createSingletonSubgraph(
30 *it, prim::DifferentiableGraph);
31 it = reverse_iterate ? ++subgraph->reverseIterator()
32 : ++subgraph->iterator();
33 first = false;
34 }
35
36 SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph);
37 it = reverse_iterate ? ++subgraph->reverseIterator()
38 : ++subgraph->iterator();
39 }
40
41 // Unmerge and compare with original node listing
42 // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
43 SubgraphUtils::unmergeSubgraph(subgraph);
44 EliminateCommonSubexpression(graph);
45
46 std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
47 ASSERT_EQ(originalNodes.size(), newNodes.size());
48 }
49}
50
51TEST(SubgraphUtilsTest, MergeSubgraphs) {
52 auto graph = std::make_shared<Graph>();
53 std::unordered_map<std::string, Value*> parse_map;
54 parseIR(
55 R"IR(
56graph(%a : Tensor, %b : Tensor, %c : Tensor):
57 %x : Tensor = aten::sigmoid(%a)
58 %y : Tensor = aten::mul(%a, %b)
59 %p : Tensor = aten::div(%c, %b)
60 %q1 : Tensor = aten::mul(%p, %a)
61 %q2 : Tensor = aten::tanh(%q1)
62 %q3 : Tensor = aten::tanh(%q2)
63 %q4 : Tensor = aten::tanh(%q3)
64 %q5 : Tensor = aten::hardsigmoid(%q4)
65 return (%x, %y, %q5))IR",
66 &*graph,
67 parse_map);
68
69 std::vector<Node*> originalNodes(
70 graph->nodes().begin(), graph->nodes().end());
71 for (bool reverse_merge : {true, false}) {
72 // Merge everything into two adjacent subgraphs
73 Node* graph1 = SubgraphUtils::createSingletonSubgraph(
74 *graph->nodes().begin(), prim::DifferentiableGraph);
75 while (true) {
76 Node* next = graph1->next();
77 if (next->kind() == aten::tanh) {
78 break;
79 }
80 SubgraphUtils::mergeNodeIntoSubgraph(next, graph1);
81 }
82 Node* graph2 = SubgraphUtils::createSingletonSubgraph(
83 graph1->next(), prim::DifferentiableGraph);
84 while (graph2->next() != *graph->nodes().end()) {
85 SubgraphUtils::mergeNodeIntoSubgraph(graph2->next(), graph2);
86 }
87 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
88 Node* subgraph;
89 if (reverse_merge) {
90 SubgraphUtils::mergeNodeIntoSubgraph(graph2, graph1);
91 subgraph = graph1;
92 } else {
93 SubgraphUtils::mergeNodeIntoSubgraph(graph1, graph2);
94 subgraph = graph2;
95 }
96 auto run_file_check = [](std::shared_ptr<Graph> graph) {
97 graph->lint();
98 testing::FileCheck()
99 .check("aten::sigmoid")
100 ->check("aten::mul")
101 ->check("aten::div")
102 ->check("aten::mul")
103 ->check_count("aten::tanh", 3)
104 ->check("aten::hardsigmoid")
105 ->run(*graph);
106 };
107 run_file_check(subgraph->g(attr::Subgraph));
108
109 // Unmerge and compare with original node listing
110 SubgraphUtils::unmergeSubgraph(subgraph);
111 EliminateCommonSubexpression(graph);
112 run_file_check(graph);
113
114 std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
115 ASSERT_EQ(originalNodes.size(), newNodes.size());
116 }
117}
118
119TEST(SubgraphUtilsTest, GraphName) {
120 auto graph = std::make_shared<Graph>();
121
122 std::unordered_map<std::string, Value*> parse_map;
123 parseIR(
124 R"IR(
125graph(%a : Tensor, %b : Tensor, %c : Tensor):
126 %x : Tensor = aten::tanh(%a)
127 %y : Tensor = aten::mul(%a, %b)
128 %p : Tensor = aten::div(%c, %b)
129 %q1 : Tensor = aten::mul(%p, %a)
130 %q2 : Tensor = aten::tanh(%q1)
131 %q3 : Tensor = aten::tanh(%q2)
132 %q4 : Tensor = aten::tanh(%q3)
133 %q5 : Tensor = aten::tanh(%q4)
134 return (%x, %y, %q5))IR",
135 &*graph,
136 parse_map);
137 std::string ref_full_name = "graph_tanh_mul_div_mul_tanh_tanh_tanh_tanh";
138 std::string full_name =
139 SubgraphUtils::generateNameForGraph(graph, 80, "graph");
140 ASSERT_EQ(full_name, ref_full_name);
141
142 std::string truncated_name =
143 SubgraphUtils::generateNameForGraph(graph, 10, "graph");
144
145 ASSERT_LE(truncated_name.size(), ref_full_name.size());
146}
147
148} // namespace jit
149} // namespace torch
150