1 | #include <torch/csrc/jit/passes/fuse_linear.h> |
2 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
3 | #include <torch/csrc/jit/passes/quantization/helper.h> |
4 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
9 | void FuseLinear(std::shared_ptr<Graph>& graph) { |
10 | std::string addmm_pattern = R"IR( |
11 | graph(%input, %weight_t, %bias, %beta, %alpha): |
12 | %res = aten::addmm(%bias, %input, %weight_t, %beta, %alpha) |
13 | return (%res))IR" ; |
14 | std::string fused_linear_addmm = R"IR( |
15 | graph(%input, %weight_t, %bias, %beta, %alpha): |
16 | %weight = aten::t(%weight_t) |
17 | %res = aten::linear(%input, %weight, %bias) |
18 | return (%res))IR" ; |
19 | |
20 | auto beta_is_one = [](const Match& match, |
21 | const std::unordered_map<std::string, Value*>& vmap) { |
22 | return is_int_constant(match, vmap, "beta" , 1); |
23 | }; |
24 | |
25 | // check %weight_t is produced by `aten::t` to make sure |
26 | // we can transform the pattern to `aten::linear` |
27 | auto weight_transposed = |
28 | [](const Match& match, |
29 | const std::unordered_map<std::string, Value*>& vmap) { |
30 | const auto& match_vmap = match.values_map; |
31 | auto v = match_vmap.at(vmap.at("weight_t" )); |
32 | return v->node()->kind() == Symbol::aten("t" ); |
33 | }; |
34 | |
35 | // replace addmm pattern to linear |
36 | SubgraphRewriter addmm_to_linear; |
37 | std::vector<std::pair<std::string, std::string>> value_mappings( |
38 | {{"weight" , "res" }, {"res" , "res" }}); |
39 | addmm_to_linear.RegisterRewritePattern( |
40 | addmm_pattern, fused_linear_addmm, value_mappings); |
41 | addmm_to_linear.runOnGraph( |
42 | graph, {aten_add_alpha_is_one, beta_is_one, weight_transposed}); |
43 | |
44 | std::string matmul_add_pattern = R"IR( |
45 | graph(%input, %weight_t, %bias, %alpha): |
46 | %output = aten::matmul(%input, %weight_t) |
47 | %res = aten::add_(%output, %bias, %alpha) |
48 | return (%res))IR" ; |
49 | std::string fused_linear_matmul = R"IR( |
50 | graph(%input, %weight_t, %bias, %alpha): |
51 | %weight = aten::t(%weight_t) |
52 | %res = aten::linear(%input, %weight, %bias) |
53 | return (%res))IR" ; |
54 | value_mappings = {{"weight" , "output" }, {"res" , "output" }}; |
55 | // replace matmul + add pattern to linear |
56 | SubgraphRewriter matmuladd_to_linear; |
57 | matmuladd_to_linear.RegisterRewritePattern( |
58 | matmul_add_pattern, fused_linear_matmul, value_mappings); |
59 | matmuladd_to_linear.runOnGraph( |
60 | graph, {aten_add_alpha_is_one, weight_transposed}); |
61 | |
62 | std::string matmul_pattern = R"IR( |
63 | graph(%input, %weight_t): |
64 | %output = aten::matmul(%input, %weight_t) |
65 | return (%output))IR" ; |
66 | std::string fused_linear_bias_none = R"IR( |
67 | graph(%input, %weight_t): |
68 | %weight = aten::t(%weight_t) |
69 | %bias: Tensor? = prim::Constant() |
70 | %res = aten::linear(%input, %weight, %bias) |
71 | return (%res))IR" ; |
72 | |
73 | // replace matmul with bias=None pattern to linear |
74 | SubgraphRewriter matmul_to_linear; |
75 | matmul_to_linear.RegisterRewritePattern( |
76 | matmul_pattern, fused_linear_bias_none, value_mappings); |
77 | matmul_to_linear.runOnGraph(graph, weight_transposed); |
78 | |
79 | // clean up extra transpose for the weight of aten::linear |
80 | std::string = R"IR( |
81 | graph(%input, %weight, %bias): |
82 | %weight_t1 = aten::t(%weight) |
83 | %weight_t2 = aten::t(%weight_t1) |
84 | %res = aten::linear(%input, %weight_t2, %bias) |
85 | return (%res))IR" ; |
86 | |
87 | std::string linear_weight_no_transpose = R"IR( |
88 | graph(%input, %weight, %bias): |
89 | %res = aten::linear(%input, %weight, %bias) |
90 | return (%res))IR" ; |
91 | |
92 | value_mappings = {{"res" , "res" }}; |
93 | SubgraphRewriter cleanup; |
94 | cleanup.RegisterRewritePattern( |
95 | linear_weight_extra_transpose, |
96 | linear_weight_no_transpose, |
97 | value_mappings); |
98 | cleanup.runOnGraph(graph); |
99 | |
100 | SwapFunctionalLinear(graph); |
101 | } |
102 | |
103 | void SwapFunctionalLinear(Module& module) { |
104 | for (auto& method : module.get_methods()) { |
105 | std::shared_ptr<Graph> g = method.graph(); |
106 | SwapFunctionalLinear(g); |
107 | } |
108 | for (Module m : module.children()) { |
109 | SwapFunctionalLinear(m); |
110 | } |
111 | } |
112 | |
113 | void SwapFunctionalLinear(std::shared_ptr<Graph>& graph) { |
114 | std::string functional_linear = R"( |
115 | graph(%linear, %input, %weight, %bias): |
116 | %r = prim::CallFunction(%linear, %input, %weight, %bias) |
117 | return (%r) )" ; |
118 | std::string aten_linear = R"( |
119 | graph(%linear, %input, %weight, %bias): |
120 | %r = aten::linear(%input, %weight, %bias) |
121 | return (%r) )" ; |
122 | |
123 | auto filter = [](const Match& match, |
124 | const std::unordered_map<std::string, Value*>& vmap) { |
125 | const auto& match_vmap = match.values_map; |
126 | auto linear = graph_rewrite_helper::getValue("linear" , match_vmap, vmap); |
127 | auto func_name = graph_rewrite_helper::getFuncName(linear); |
128 | return func_name == "linear" ; |
129 | }; |
130 | SubgraphRewriter rewriter; |
131 | rewriter.RegisterRewritePattern(functional_linear, aten_linear); |
132 | rewriter.runOnGraph(graph, filter); |
133 | } |
134 | |
135 | } // namespace jit |
136 | } // namespace torch |
137 | |