#include <ATen/Utils.h>

#include <ATen/code_template.h>
#include <ATen/cuda/CUDAConfig.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch {
namespace jit {

namespace {
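// Rewrites frozen conv2d/conv3d -> [add ->] relu chains into the fused
// cudnn_convolution_relu / cudnn_convolution_add_relu ops (or their miopen_*
// counterparts on ROCm). The body is a no-op unless cuDNN or MIOpen is
// enabled.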
void fuseFrozenConvAddReluImpl(std::shared_ptr<Graph>& graph) {
#if AT_CUDNN_ENABLED() || AT_ROCM_ENABLED()
  GRAPH_DEBUG("Before fuseFrozenConvAddReluImpl: ", *graph);
  SubgraphRewriter rewriter;

  // CUDNN does not support conv1d
  std::array<std::string, 2> conv_operators = {"conv2d", "conv3d"};
  std::array<std::string, 2> add_operators = {"add", "add_"};
  std::array<std::string, 2> relu_operators = {"relu", "relu_"};

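  // Pattern: conv2d/conv3d followed by relu. The ${conv} and ${relu}
  // placeholders are substituted with each operator variant in the loops below.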
  auto conv_relu_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      %res = aten::${relu}(%x)
      return (%res))");

#ifdef USE_ROCM
  std::string conv_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::miopen_convolution_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#else
  std::string conv_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::cudnn_convolution_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#endif

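  // Pattern: conv2d/conv3d, then an (optionally in-place) add of a residual
  // tensor %z scaled by %alpha, then relu.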
  auto conv_add_relu_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      %y = aten::${add}(%x, %z, %alpha)
      %res = aten::${relu}(%y)
      return (%res))");

#ifdef USE_ROCM
  std::string conv_add_relu_fused = R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::miopen_convolution_add_relu(%input, %weight, %z, %alpha, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#else
  std::string conv_add_relu_fused = R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::cudnn_convolution_add_relu(%input, %weight, %z, %alpha, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#endif

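  // Register a rewrite for every combination of conv/relu (and, for the
  // conv-add-relu pattern, every add variant).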
  for (const auto& conv : conv_operators) {
    for (const auto& relu : relu_operators) {
      at::jit::TemplateEnv env;
      env.s("conv", conv);
      env.s("relu", relu);
      rewriter.RegisterRewritePattern(
          conv_relu_rstring.format(env), conv_relu_fused);
      for (const auto& add : add_operators) {
        env.s("add", add);
        rewriter.RegisterRewritePattern(
            conv_add_relu_rstring.format(env), conv_add_relu_fused);
      }
    }
  }

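  // Gate the rewrite: the frozen weight (and, if present, the bias and the
  // residual input %z) must be constant CUDA tensors with matching dtypes,
  // compatible shapes, and contiguous layouts.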
  auto filter = [](const Match& match,
                   const std::unordered_map<std::string, Value*>& vmap) {
    auto weight = toIValue(match.values_map.at(vmap.at("weight")));
    if (!weight.has_value() || !weight.value().isTensor()) {
      return false;
    }
    const at::Tensor& weight_t = weight.value().toTensor();
    if (!weight_t.device().is_cuda() || !weight_t.is_contiguous()) {
      return false;
    }

    // bias is optional
    if (vmap.find("bias") != vmap.end()) {
      auto bias = toIValue(match.values_map.at(vmap.at("bias")));
      if (bias.has_value() && bias.value().isTensor()) {
        const at::Tensor& bias_t = bias.value().toTensor();
        if (bias_t.dtype() != weight_t.dtype() || bias_t.ndimension() != 1 ||
            bias_t.size(0) != weight_t.size(0) || !bias_t.device().is_cuda()) {
          return false;
        }
      }
    }

    // z is optional
    if (vmap.find("z") != vmap.end()) {
      auto z = toIValue(match.values_map.at(vmap.at("z")));
      if (z.has_value() && z.value().isTensor()) {
        const at::Tensor& z_t = z.value().toTensor();
        if (z_t.dtype() != weight_t.dtype() ||
            z_t.size(0) != weight_t.size(0) || !z_t.is_contiguous() ||
            !z_t.device().is_cuda()) {
          return false;
        }
      }
    }
    return true;
  };

  // Convert _convolution and in-place operators for simpler replacement pattern
  // matching
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  rewriter.runOnGraph(graph, filter);
  GRAPH_DEBUG("After fuseFrozenConvAddReluImpl: ", *graph);
#endif
}

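// Register this implementation through getFuseFrozenConvAddReluImpl()
// (declared in frozen_conv_add_relu_fusion.h) at static initialization time,
// so the generic frozen conv-add-relu fusion pass can dispatch to it when this
// translation unit is linked in.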
auto dummyInitializer = []() {
  getFuseFrozenConvAddReluImpl() = fuseFrozenConvAddReluImpl;
  return true;
}();

} // namespace

} // namespace jit
} // namespace torch