#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch {
namespace jit {

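// Fuses decomposed forms of linear back into a single aten::linear op:
//   1. aten::addmm(bias, input, weight_t, beta, alpha)
//   2. aten::matmul(input, weight_t) followed by aten::add_(output, bias)
//   3. aten::matmul(input, weight_t) with no bias
// In all cases %weight_t must be produced by an aten::t node, since
// aten::linear takes the untransposed weight. Functional linear calls are
// swapped to aten::linear at the end via SwapFunctionalLinear.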
void FuseLinear(std::shared_ptr<Graph>& graph) {
  std::string addmm_pattern = R"IR(
    graph(%input, %weight_t, %bias, %beta, %alpha):
        %res = aten::addmm(%bias, %input, %weight_t, %beta, %alpha)
        return (%res))IR";
  std::string fused_linear_addmm = R"IR(
    graph(%input, %weight_t, %bias, %beta, %alpha):
        %weight = aten::t(%weight_t)
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";

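  // aten::addmm computes beta * bias + alpha * (input @ weight_t), while
  // aten::linear computes input @ weight.t() + bias with no scaling, so the
  // rewrite is only valid when both beta and alpha are the constant 1
  // (alpha is checked by aten_add_alpha_is_one below).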
  auto beta_is_one = [](const Match& match,
                        const std::unordered_map<std::string, Value*>& vmap) {
    return is_int_constant(match, vmap, "beta", 1);
  };

  // check that %weight_t is produced by `aten::t` to make sure
  // we can transform the pattern to `aten::linear`
  auto weight_transposed =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        const auto& match_vmap = match.values_map;
        auto v = match_vmap.at(vmap.at("weight_t"));
        return v->node()->kind() == Symbol::aten("t");
      };

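  // value_mappings pairs each value created by the replacement graph with a
  // value in the matched pattern; SubgraphRewriter uses the pairs to carry
  // source ranges and callstacks over to the newly created nodes.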
  // replace the addmm pattern with aten::linear
  SubgraphRewriter addmm_to_linear;
  std::vector<std::pair<std::string, std::string>> value_mappings(
      {{"weight", "res"}, {"res", "res"}});
  addmm_to_linear.RegisterRewritePattern(
      addmm_pattern, fused_linear_addmm, value_mappings);
  addmm_to_linear.runOnGraph(
      graph, {aten_add_alpha_is_one, beta_is_one, weight_transposed});

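  // A linear layer can also appear as a matmul followed by an in-place add
  // of the bias, e.g. input.matmul(weight.t()).add_(bias) after inlining.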
  std::string matmul_add_pattern = R"IR(
    graph(%input, %weight_t, %bias, %alpha):
        %output = aten::matmul(%input, %weight_t)
        %res = aten::add_(%output, %bias, %alpha)
        return (%res))IR";
  std::string fused_linear_matmul = R"IR(
    graph(%input, %weight_t, %bias, %alpha):
        %weight = aten::t(%weight_t)
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";

  // replace the matmul + add pattern with aten::linear
  value_mappings = {{"weight", "output"}, {"res", "output"}};
  SubgraphRewriter matmuladd_to_linear;
  matmuladd_to_linear.RegisterRewritePattern(
      matmul_add_pattern, fused_linear_matmul, value_mappings);
  matmuladd_to_linear.runOnGraph(
      graph, {aten_add_alpha_is_one, weight_transposed});

  std::string matmul_pattern = R"IR(
    graph(%input, %weight_t):
        %output = aten::matmul(%input, %weight_t)
        return (%output))IR";
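  // a prim::Constant with no value and type Tensor? produces a None
  // constant, i.e. the fused aten::linear is called without a bias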
  std::string fused_linear_bias_none = R"IR(
    graph(%input, %weight_t):
        %weight = aten::t(%weight_t)
        %bias: Tensor? = prim::Constant()
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";

  // replace the matmul-with-bias=None pattern with aten::linear; the
  // replacement graph creates %bias as a new value, so it needs its own
  // entry in value_mappings
  value_mappings = {
      {"weight", "output"}, {"bias", "output"}, {"res", "output"}};
  SubgraphRewriter matmul_to_linear;
  matmul_to_linear.RegisterRewritePattern(
      matmul_pattern, fused_linear_bias_none, value_mappings);
  matmul_to_linear.runOnGraph(graph, weight_transposed);

  // clean up the extra transpose on the weight of aten::linear: the rewrites
  // above insert aten::t(%weight_t), and since %weight_t itself comes from
  // an aten::t node, this leaves a t(t(%weight)) chain that cancels out
  std::string linear_weight_extra_transpose = R"IR(
    graph(%input, %weight, %bias):
        %weight_t1 = aten::t(%weight)
        %weight_t2 = aten::t(%weight_t1)
        %res = aten::linear(%input, %weight_t2, %bias)
        return (%res))IR";

  std::string linear_weight_no_transpose = R"IR(
    graph(%input, %weight, %bias):
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";

  value_mappings = {{"res", "res"}};
  SubgraphRewriter cleanup;
  cleanup.RegisterRewritePattern(
      linear_weight_extra_transpose,
      linear_weight_no_transpose,
      value_mappings);
  cleanup.runOnGraph(graph);

  SwapFunctionalLinear(graph);
}

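// Recursively replaces prim::CallFunction(linear, ...) with aten::linear in
// every method of the module and of its submodules.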
void SwapFunctionalLinear(Module& module) {
  for (auto& method : module.get_methods()) {
    std::shared_ptr<Graph> g = method.graph();
    SwapFunctionalLinear(g);
  }
  for (Module m : module.children()) {
    SwapFunctionalLinear(m);
  }
}

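// The pattern matches any prim::CallFunction; the filter below only accepts
// matches whose callee is a function named "linear" (in practice,
// torch.nn.functional.linear).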
void SwapFunctionalLinear(std::shared_ptr<Graph>& graph) {
  std::string functional_linear = R"(
graph(%linear, %input, %weight, %bias):
  %r = prim::CallFunction(%linear, %input, %weight, %bias)
  return (%r) )";
  std::string aten_linear = R"(
graph(%linear, %input, %weight, %bias):
  %r = aten::linear(%input, %weight, %bias)
  return (%r) )";

  auto filter = [](const Match& match,
                   const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto linear = graph_rewrite_helper::getValue("linear", match_vmap, vmap);
    auto func_name = graph_rewrite_helper::getFuncName(linear);
    return func_name == "linear";
  };
  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(functional_linear, aten_linear);
  rewriter.runOnGraph(graph, filter);
}

} // namespace jit
} // namespace torch