1#include <torch/csrc/jit/passes/quantization/fusion_passes.h>
2#include <torch/csrc/jit/passes/subgraph_rewrite.h>
3
4namespace torch {
5namespace jit {
6
7namespace {
8void fuseQuantizeAddReluImpl(std::shared_ptr<Graph>& graph) {
9 SubgraphRewriter fused_add_relu_rewriter;
10 std::string quantized_add_relu_pattern = R"(
11 graph(%a_quant, %b_quant, %scale, %zero_point):
12 %add_out = quantized::add(%a_quant, %b_quant, %scale, %zero_point)
13 %r = aten::relu(%add_out)
14 return (%r) )";
15 std::string fused_add_relu_pattern = R"(
16 graph(%a_quant, %b_quant, %scale, %zero_point):
17 %r = quantized::add_relu(%a_quant, %b_quant, %scale, %zero_point)
18 return (%r) )";
19 fused_add_relu_rewriter.RegisterRewritePattern(
20 quantized_add_relu_pattern, fused_add_relu_pattern);
21 std::string quantized_add_out_relu_pattern = R"(
22 graph(%a_quant, %b_quant, %out_quant):
23 %add_out = quantized::add_out(%a_quant, %b_quant, %out_quant)
24 %r = aten::relu(%add_out)
25 return (%r) )";
26 std::string fused_add_out_relu_pattern = R"(
27 graph(%a_quant, %b_quant, %out_quant):
28 %r = quantized::add_relu_out(%a_quant, %b_quant, %out_quant)
29 return (%r) )";
30 fused_add_relu_rewriter.RegisterRewritePattern(
31 quantized_add_out_relu_pattern, fused_add_out_relu_pattern);
32 std::string quantized_add_scalar_relu_pattern = R"(
33 graph(%a_quant, %b_scalar):
34 %add_out = quantized::add_scalar(%a_quant, %b_scalar)
35 %r = aten::relu(%add_out)
36 return (%r) )";
37 std::string fused_add_scalar_relu_pattern = R"(
38 graph(%a_quant, %b_scalar):
39 %r = quantized::add_scalar_relu(%a_quant, %b_scalar)
40 return (%r) )";
41 fused_add_relu_rewriter.RegisterRewritePattern(
42 quantized_add_scalar_relu_pattern, fused_add_scalar_relu_pattern);
43 std::string quantized_add_scalar_out_relu_pattern = R"(
44 graph(%a_quant, %b_scalar, %out_quant):
45 %add_out = quantized::add_scalar_out(%a_quant, %b_scalar, %out_quant)
46 %r = aten::relu(%add_out)
47 return (%r) )";
48 std::string fused_add_scalar_out_relu_pattern = R"(
49 graph(%a_quant, %b_scalar, %out_quant):
50 %r = quantized::add_scalar_relu_out(%a_quant, %b_scalar, %out_quant)
51 return (%r) )";
52 fused_add_relu_rewriter.RegisterRewritePattern(
53 quantized_add_scalar_out_relu_pattern, fused_add_scalar_out_relu_pattern);
54 fused_add_relu_rewriter.runOnGraph(graph);
55}
56} // namespace
57
58void FuseQuantizedAddRelu(std::shared_ptr<Graph>& graph) {
59 fuseQuantizeAddReluImpl(graph);
60}
61
62} // namespace jit
63} // namespace torch
64