1 | #include <torch/csrc/jit/passes/quantization/fusion_passes.h> |
2 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | |
7 | namespace { |
8 | void fuseQuantizeAddReluImpl(std::shared_ptr<Graph>& graph) { |
9 | SubgraphRewriter fused_add_relu_rewriter; |
10 | std::string quantized_add_relu_pattern = R"( |
11 | graph(%a_quant, %b_quant, %scale, %zero_point): |
12 | %add_out = quantized::add(%a_quant, %b_quant, %scale, %zero_point) |
13 | %r = aten::relu(%add_out) |
14 | return (%r) )" ; |
15 | std::string fused_add_relu_pattern = R"( |
16 | graph(%a_quant, %b_quant, %scale, %zero_point): |
17 | %r = quantized::add_relu(%a_quant, %b_quant, %scale, %zero_point) |
18 | return (%r) )" ; |
19 | fused_add_relu_rewriter.RegisterRewritePattern( |
20 | quantized_add_relu_pattern, fused_add_relu_pattern); |
21 | std::string quantized_add_out_relu_pattern = R"( |
22 | graph(%a_quant, %b_quant, %out_quant): |
23 | %add_out = quantized::add_out(%a_quant, %b_quant, %out_quant) |
24 | %r = aten::relu(%add_out) |
25 | return (%r) )" ; |
26 | std::string fused_add_out_relu_pattern = R"( |
27 | graph(%a_quant, %b_quant, %out_quant): |
28 | %r = quantized::add_relu_out(%a_quant, %b_quant, %out_quant) |
29 | return (%r) )" ; |
30 | fused_add_relu_rewriter.RegisterRewritePattern( |
31 | quantized_add_out_relu_pattern, fused_add_out_relu_pattern); |
32 | std::string quantized_add_scalar_relu_pattern = R"( |
33 | graph(%a_quant, %b_scalar): |
34 | %add_out = quantized::add_scalar(%a_quant, %b_scalar) |
35 | %r = aten::relu(%add_out) |
36 | return (%r) )" ; |
37 | std::string fused_add_scalar_relu_pattern = R"( |
38 | graph(%a_quant, %b_scalar): |
39 | %r = quantized::add_scalar_relu(%a_quant, %b_scalar) |
40 | return (%r) )" ; |
41 | fused_add_relu_rewriter.RegisterRewritePattern( |
42 | quantized_add_scalar_relu_pattern, fused_add_scalar_relu_pattern); |
43 | std::string quantized_add_scalar_out_relu_pattern = R"( |
44 | graph(%a_quant, %b_scalar, %out_quant): |
45 | %add_out = quantized::add_scalar_out(%a_quant, %b_scalar, %out_quant) |
46 | %r = aten::relu(%add_out) |
47 | return (%r) )" ; |
48 | std::string fused_add_scalar_out_relu_pattern = R"( |
49 | graph(%a_quant, %b_scalar, %out_quant): |
50 | %r = quantized::add_scalar_relu_out(%a_quant, %b_scalar, %out_quant) |
51 | return (%r) )" ; |
52 | fused_add_relu_rewriter.RegisterRewritePattern( |
53 | quantized_add_scalar_out_relu_pattern, fused_add_scalar_out_relu_pattern); |
54 | fused_add_relu_rewriter.runOnGraph(graph); |
55 | } |
56 | } // namespace |
57 | |
58 | void FuseQuantizedAddRelu(std::shared_ptr<Graph>& graph) { |
59 | fuseQuantizeAddReluImpl(graph); |
60 | } |
61 | |
62 | } // namespace jit |
63 | } // namespace torch |
64 | |