1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | #include <torch/csrc/jit/operator_upgraders/upgraders.h> |
5 | #include <torch/csrc/jit/operator_upgraders/version_map.h> |
6 | #include <torch/csrc/jit/passes/replacement_of_old_operators.h> |
7 | #include <memory> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
namespace {

// Test-only upgrader fixtures, keyed by upgrader name and holding the
// TorchScript IR of each upgrader body. The tests in this file pass this map
// to test_only_populate_upgraders() / test_only_remove_upgraders().
//
// Declared in an unnamed namespace so this mutable test global has internal
// linkage. It therefore cannot clash at link time with an identically named
// symbol in another test translation unit linked into the same binary.
//
// - "_test_serialization_subcmul_0_2": computes other - self * alpha
//   (aten::mul followed by aten::sub with alpha=1 on the subtraction).
// - "div_Tensor_0_3": if either operand is floating point, emits plain
//   aten::div; otherwise emits aten::div with rounding mode "trunc".
std::unordered_map<std::string, std::string> test_upgraders(
    {{"_test_serialization_subcmul_0_2", R"IR(graph(%self.1 : Tensor,
                                                    %other.1 : Tensor,
                                                    %alpha.1 : Union(float, int)):
  %7 : int = prim::Constant[value=1]()
  %6 : Tensor = aten::mul(%self.1, %alpha.1) # torch/jit/operator_upgraders.py:18:20
  %8 : Tensor = aten::sub(%other.1, %6, %7) # torch/jit/operator_upgraders.py:18:11
  return (%8))IR"},
     {"div_Tensor_0_3", R"IR(graph(%self.1 : Tensor,
                                  %other.1 : Tensor):
  %32 : str = prim::Constant[value="trunc"]()
  %6 : bool = prim::Constant[value=1]()
  %4 : bool = aten::is_floating_point(%self.1)
  %11 : bool = prim::If(%4)
    block0():
      -> (%6)
    block1():
      %9 : bool = aten::is_floating_point(%other.1)
      -> (%9)
  %35 : Tensor = prim::If(%11)
    block0():
      %36 : Tensor = aten::div(%self.1, %other.1)
      -> (%36)
    block1():
      %37 : Tensor = aten::div(%self.1, %other.1, %32)
      -> (%37)
  return (%35))IR"}});

} // namespace
39 | |
40 | TEST(OpReplacementTest, ReplaceDivInSimpleFunction) { |
41 | const auto graph_string = R"IR( |
42 | graph(%0 : Tensor, |
43 | %1 : Tensor): |
44 | %2 : Tensor = aten::add(%0, %1) |
45 | %3 : Tensor = aten::div(%2, %1) |
46 | return (%3))IR" ; |
47 | auto g = std::make_shared<Graph>(); |
48 | test_only_populate_upgraders(test_upgraders); |
49 | torch::jit::parseIR(graph_string, g.get()); |
50 | g->set_op_version(2); |
51 | ReplaceOldOperatorsWithUpgraders(g); |
52 | testing::FileCheck() |
53 | .check("prim::If" ) |
54 | ->check_count("aten::div(%2, %1)" , 1, /*exactly=*/true) |
55 | ->check_count("aten::div(%2, %1, %4)" , 1, /*exactly=*/true) |
56 | ->run(*g); |
57 | } |
58 | |
59 | TEST(OpReplacementTest, ReplaceTwoOpsInSimpleFunction) { |
60 | const auto graph_string = R"IR( |
61 | graph(%0 : Tensor, |
62 | %1 : Tensor): |
63 | %2 : Tensor = aten::add(%0, %1) |
64 | %3 : Tensor = aten::div(%2, %1) |
65 | %4 : int = prim::Constant[value=1]() |
66 | %5: Tensor = aten::_test_serialization_subcmul(%0, %1, %4) |
67 | return (%3, %5))IR" ; |
68 | auto g = std::make_shared<Graph>(); |
69 | test_only_populate_upgraders(test_upgraders); |
70 | UpgraderEntry test_entry{ |
71 | 3, |
72 | "_test_serialization_subcmul_0_2" , |
73 | "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor" }; |
74 | test_only_add_entry("aten::_test_serialization_subcmul" , test_entry); |
75 | torch::jit::parseIR(graph_string, g.get()); |
76 | g->set_op_version(2); |
77 | ReplaceOldOperatorsWithUpgraders(g); |
78 | testing::FileCheck() |
79 | .check("prim::If" ) |
80 | ->check_count("aten::div" , 2, /*exactly=*/true) |
81 | ->run(*g); |
82 | test_only_remove_entry("aten::_test_serialization_subcmul" ); |
83 | test_only_remove_upgraders(test_upgraders); |
84 | } |
85 | |
86 | TEST(OpReplacementTest, ReplaceDivInNestedFunction) { |
87 | const auto graph_string = R"IR( |
88 | graph(%0 : Tensor, |
89 | %1 : Tensor, |
90 | %8 : bool): |
91 | %9 : bool = prim::Constant[value=1]() |
92 | %7 : bool = prim::If(%8) |
93 | block0(): |
94 | -> (%9) |
95 | block1(): |
96 | %2 : Tensor = aten::add(%0, %1) |
97 | %3 : Tensor = aten::div(%2, %1) |
98 | %4 : Tensor = aten::add(%3, %0) |
99 | %10 : bool = aten::is_floating_point(%4) |
100 | -> (%10) |
101 | return (%7))IR" ; |
102 | auto g = std::make_shared<Graph>(); |
103 | test_only_populate_upgraders(test_upgraders); |
104 | torch::jit::parseIR(graph_string, g.get()); |
105 | g->set_op_version(2); |
106 | ReplaceOldOperatorsWithUpgraders(g); |
107 | testing::FileCheck() |
108 | .check("prim::If" ) |
109 | ->check_count("aten::add" , 2, false) |
110 | ->run(*g); |
111 | |
112 | testing::FileCheck() |
113 | .check("prim::If" ) |
114 | ->check_count("aten::div" , 2, false) |
115 | ->run(*g); |
116 | test_only_remove_upgraders(test_upgraders); |
117 | } |
118 | |
119 | TEST(OpReplacementTest, ReplaceTestSubcmulInSimpleFunction) { |
120 | const auto graph_string = R"IR( |
121 | graph(%0 : Tensor, |
122 | %1 : Tensor): |
123 | %3 : int = prim::Constant[value=1]() |
124 | %2 : Tensor = aten::_test_serialization_subcmul(%0, %1, %3) |
125 | return (%2))IR" ; |
126 | auto g = std::make_shared<Graph>(); |
127 | test_only_populate_upgraders(test_upgraders); |
128 | UpgraderEntry test_entry{ |
129 | 3, |
130 | "_test_serialization_subcmul_0_2" , |
131 | "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor" }; |
132 | test_only_add_entry("aten::_test_serialization_subcmul" , test_entry); |
133 | torch::jit::parseIR(graph_string, g.get()); |
134 | g->set_op_version(2); |
135 | ReplaceOldOperatorsWithUpgraders(g); |
136 | testing::FileCheck().check_count("aten::mul" , 1, false)->run(*g); |
137 | |
138 | testing::FileCheck().check_count("aten::sub" , 1, false)->run(*g); |
139 | |
140 | test_only_remove_upgraders(test_upgraders); |
141 | test_only_remove_entry("aten::_test_serialization_subcmul" ); |
142 | } |
143 | |
144 | } // namespace jit |
145 | } // namespace torch |
146 | |