1 | #include <torch/csrc/jit/passes/fuse_relu.h> |
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
5 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | namespace { |
11 | void fuseAddReluImpl(std::shared_ptr<Graph>& graph) { |
12 | SubgraphRewriter rewriter; |
13 | |
14 | std::string add_relu_0 = R"( |
15 | graph(%a, %b, %alpha): |
16 | %add_res = aten::add(%a, %b, %alpha) |
17 | %res = aten::relu(%add_res) |
18 | return (%res))" ; |
19 | std::string add_relu_fused = R"( |
20 | graph(%a, %b, %alpha): |
21 | %res = aten::_add_relu(%a, %b, %alpha) |
22 | return (%res))" ; |
23 | rewriter.RegisterRewritePattern(add_relu_0, add_relu_fused); |
24 | |
25 | std::string add_relu_1 = R"( |
26 | graph(%a, %b, %alpha): |
27 | %add_res = aten::add(%a, %b, %alpha) |
28 | %res = aten::relu_(%add_res) |
29 | return (%res))" ; |
30 | rewriter.RegisterRewritePattern(add_relu_1, add_relu_fused); |
31 | |
32 | std::string add_inplace_relu_1 = R"( |
33 | graph(%a, %b, %alpha): |
34 | %add_res = aten::add_(%a, %b, %alpha) |
35 | %res = aten::relu_(%add_res) |
36 | return (%res))" ; |
37 | std::string add_inplace_relu_fused = R"( |
38 | graph(%a, %b, %alpha): |
39 | %res = aten::_add_relu_(%a, %b, %alpha) |
40 | return (%res))" ; |
41 | rewriter.RegisterRewritePattern(add_inplace_relu_1, add_inplace_relu_fused); |
42 | |
43 | std::string add_out_relu = R"( |
44 | graph(%a, %b, %alpha, %out): |
45 | %add_res = aten::add(%a, %b, %alpha, %out) |
46 | %res = aten::relu_(%add_res) |
47 | return (%res))" ; |
48 | std::string add_out_relu_fused = R"( |
49 | graph(%a, %b, %alpha, %out): |
50 | %res = aten::_add_relu(%a, %b, %alpha, %out) |
51 | return (%res))" ; |
52 | |
53 | rewriter.RegisterRewritePattern(add_out_relu, add_out_relu_fused); |
54 | |
55 | rewriter.runOnGraph(graph); |
56 | // NB: Patterns that are left out are add_ + relu and add_out + relu |
57 | // This is because inplace mutation of the testor done by add_ will be lost if |
58 | // inplace mutatation of the same tensor actually does add+relu |
59 | } |
60 | } // namespace |
61 | |
62 | void FuseAddRelu(script::Module& module) { |
63 | auto graph = module.get_method("forward" ).graph(); |
64 | fuseAddReluImpl(graph); |
65 | } |
66 | |
// Graph overload: fuses eligible add+relu sequences directly in `graph`,
// mutating it in place.
void FuseAddRelu(std::shared_ptr<Graph>& graph) {
  fuseAddReluImpl(graph);
}
70 | } // namespace jit |
71 | } // namespace torch |
72 | |