1#include <torch/csrc/jit/passes/fuse_relu.h>
2
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/ir/subgraph_matcher.h>
5#include <torch/csrc/jit/passes/subgraph_rewrite.h>
6
7namespace torch {
8namespace jit {
9
10namespace {
11void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
12 SubgraphRewriter rewriter;
13
14 std::string add_relu_0 = R"(
15 graph(%a, %b, %alpha):
16 %add_res = aten::add(%a, %b, %alpha)
17 %res = aten::relu(%add_res)
18 return (%res))";
19 std::string add_relu_fused = R"(
20 graph(%a, %b, %alpha):
21 %res = aten::_add_relu(%a, %b, %alpha)
22 return (%res))";
23 rewriter.RegisterRewritePattern(add_relu_0, add_relu_fused);
24
25 std::string add_relu_1 = R"(
26 graph(%a, %b, %alpha):
27 %add_res = aten::add(%a, %b, %alpha)
28 %res = aten::relu_(%add_res)
29 return (%res))";
30 rewriter.RegisterRewritePattern(add_relu_1, add_relu_fused);
31
32 std::string add_inplace_relu_1 = R"(
33 graph(%a, %b, %alpha):
34 %add_res = aten::add_(%a, %b, %alpha)
35 %res = aten::relu_(%add_res)
36 return (%res))";
37 std::string add_inplace_relu_fused = R"(
38 graph(%a, %b, %alpha):
39 %res = aten::_add_relu_(%a, %b, %alpha)
40 return (%res))";
41 rewriter.RegisterRewritePattern(add_inplace_relu_1, add_inplace_relu_fused);
42
43 std::string add_out_relu = R"(
44 graph(%a, %b, %alpha, %out):
45 %add_res = aten::add(%a, %b, %alpha, %out)
46 %res = aten::relu_(%add_res)
47 return (%res))";
48 std::string add_out_relu_fused = R"(
49 graph(%a, %b, %alpha, %out):
50 %res = aten::_add_relu(%a, %b, %alpha, %out)
51 return (%res))";
52
53 rewriter.RegisterRewritePattern(add_out_relu, add_out_relu_fused);
54
55 rewriter.runOnGraph(graph);
56 // NB: Patterns that are left out are add_ + relu and add_out + relu
57 // This is because inplace mutation of the testor done by add_ will be lost if
58 // inplace mutatation of the same tensor actually does add+relu
59}
60} // namespace
61
62void FuseAddRelu(script::Module& module) {
63 auto graph = module.get_method("forward").graph();
64 fuseAddReluImpl(graph);
65}
66
67void FuseAddRelu(std::shared_ptr<Graph>& graph) {
68 fuseAddReluImpl(graph);
69}
70} // namespace jit
71} // namespace torch
72