1#include <ATen/Config.h>
2#include <ATen/code_template.h>
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/jit_log.h>
5#include <torch/csrc/jit/passes/constant_propagation.h>
6#include <torch/csrc/jit/passes/dead_code_elimination.h>
7#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
8#include <torch/csrc/jit/passes/mkldnn_rewrite.h>
9#include <torch/csrc/jit/tensorexpr/kernel.h>
10
11namespace torch {
12namespace jit {
13
14#if AT_MKLDNN_ENABLED()
15
16c10::VaryingShape<int64_t> getSizesOf(Node* n, size_t idx) {
17 auto tt = n->input(idx)->type()->cast<TensorType>();
18 return tt->sizes();
19}
20
21void insertPrePackedConvOpForNode(Node* n) {
22 constexpr int POS_INPUT = 0;
23 constexpr int POS_WEIGHT = 1;
24 if (!tensorexpr::isContiguous(
25 n->input(POS_INPUT), at::MemoryFormat::ChannelsLast)) {
26 GRAPH_DEBUG(
27 "insertPrePackedConvOpForNode: input is not ChannelsLast contiguous");
28 return;
29 }
30
31 if (!tensorexpr::isContiguous(
32 n->input(POS_WEIGHT), at::MemoryFormat::ChannelsLast)) {
33 GRAPH_DEBUG(
34 "insertPrePackedConvOpForNode: weight is not ChannelsLast contiguous");
35 return;
36 }
37
38 // Leave depthwise conv2d to NNC
39 if (tensorexpr::conv2dIsSupportedJit(n)) {
40 GRAPH_DEBUG("insertPrePackedConvOpForNode: leave depthwise conv2d to NNC");
41 return;
42 }
43
44 WithInsertPoint guard(n);
45 auto graph = n->owningGraph();
46
47 auto input_sizes = getSizesOf(n, POS_INPUT);
48 IValue input_size_value(*input_sizes.concrete_sizes());
49 auto input_size = graph->insertConstant(input_size_value);
50
51 auto prepack_node = graph->create(
52 Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"), 1);
53
54 // skip input value
55 for (auto i = 1; i < n->inputs().size(); i++) {
56 Value* v = n->input(i);
57 prepack_node->addInput(v);
58 }
59 prepack_node->addInput(input_size);
60 auto attr = graph->insertConstant(IValue("none"));
61 prepack_node->addInput(attr);
62 prepack_node->output()->setType(
63 getCustomClass("__torch__.torch.classes.mkldnn.ConvOpContext"));
64 graph->insertNode(prepack_node);
65
66 auto prepack_conv = graph->insertNode(
67 graph->create(Symbol::fromQualString("mkldnn_prepacked::conv2d_run"), 1));
68 prepack_conv->addInput(n->input(0));
69 prepack_conv->addInput(prepack_node->output());
70 prepack_conv->output()->setType(n->output()->type()->cast<TensorType>());
71
72 n->output()->replaceAllUsesWith(prepack_conv->output());
73}
74
75bool isTensorTypeCPU(Node* node) {
76 for (const auto& input : node->inputs()) {
77 auto type = input->type()->cast<TensorType>();
78 if (!type) {
79 continue;
80 }
81 auto device = type->device();
82 if (!device) {
83 return false;
84 }
85 if (!device->is_cpu()) {
86 return false;
87 }
88 }
89 return true;
90}
91
92void insertPrePackedConvOp(Block* b) {
93 for (Node* n : b->nodes()) {
94 for (Block* b : n->blocks()) {
95 insertPrePackedConvOp(b);
96 }
97
98 if (n->kind() == aten::conv2d) {
99 if (isTensorTypeCPU(n)) {
100 insertPrePackedConvOpForNode(n);
101 }
102 }
103 }
104 EliminateDeadCode(b);
105}
106
// Entry point for the conv2d prepack rewrite: applies it to the whole graph.
void insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  insertPrePackedConvOp(graph->block());
}
110
// Insert all MKL-DNN prepacked ops into the graph (currently only conv2d).
void insertMkldnnPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertMkldnnPrePackedConv2dOp(graph);
}
114
115void insertMkldnnPrePackedOps(script::Module& module) {
116 for (auto& method : module.get_methods()) {
117 auto graph = method.graph();
118 insertMkldnnPrePackedOps(graph);
119 }
120 for (script::Module m : module.children()) {
121 insertMkldnnPrePackedOps(m);
122 }
123}
124
125void FuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
126 auto conv_op_rstring = at::jit::CodeTemplate(R"(
127 graph(%input, %weight, %bias, %stride:int[], %padding:int[],
128 %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
129 %packed_weight_bias = mkldnn_prepacked::conv2d_prepack(
130 %weight, %bias, %stride, %padding, %dilation, %groups,
131 %input_size, %dummy_attr)
132 %conv2d_res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias)
133 %res = aten::${op}(%conv2d_res)
134 return (%res))");
135
136 auto conv_op_fused_rstring = at::jit::CodeTemplate(R"(
137 graph(%input, %weight, %bias, %stride:int[], %padding:int[],
138 %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
139 %attr: str = prim::Constant[value="${op_attr}"]()
140 %packed_weight_bias : __torch__.torch.classes.mkldnn.ConvOpContext = mkldnn_prepacked::conv2d_prepack(
141 %weight, %bias, %stride, %padding, %dilation, %groups,
142 %input_size, %attr)
143 %res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias)
144 return (%res))");
145
146 for (auto const& it : mkldnn::fusion_rewrite_map) {
147 std::string op = it.first;
148 if (op == std::string("none")) {
149 continue;
150 }
151
152 at::jit::TemplateEnv env;
153 env.s("op", op);
154
155 at::jit::TemplateEnv env_fused;
156 env_fused.s("op_attr", op);
157
158 SubgraphRewriter rewriter;
159 rewriter.RegisterRewritePattern(
160 conv_op_rstring.format(env), conv_op_fused_rstring.format(env_fused));
161
162 auto filters = it.second;
163 rewriter.runOnGraph(graph, filters);
164 }
165}
166
// Constant-fold mkldnn_prepacked::conv2d_prepack nodes whose inputs are all
// constants: run the prepack once at compile time and replace the node's
// output with a constant holding the resulting ConvOpContext object.
void PrePackingOpsFolder(Block* b) {
  auto is_foldable_op = [](const Node* n) -> bool {
    return (
        n->kind() ==
        Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"));
  };

  std::unordered_set<Node*> nodes_to_delete;
  for (Node* n : b->nodes()) {
    // Recurse into sub-blocks first (e.g. prim::If / prim::Loop bodies).
    for (Block* block : n->blocks()) {
      PrePackingOpsFolder(block);
    }
    if (is_foldable_op(n)) {
      // Returns nullopt unless every input is a constant.
      auto optional_outputs = torch::jit::runNodeIfInputsAreConstant(n);
      if (optional_outputs) {
        auto outputs = optional_outputs.value();
        TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output");
        Value* prepack_op_value = n->output(0);
        auto graph = n->owningGraph();
        WithInsertPoint ins(prepack_op_value->node());
        // NOTE(review): presumably a weak reference is taken so the constant
        // does not keep the compilation unit alive — confirm against
        // copy_to_weak_compilation_ref semantics.
        auto weak_class_obj =
            outputs[0].toObject()->copy_to_weak_compilation_ref();
        Value* packed_weight = graph->insertConstant(weak_class_obj)
                                   ->setType(n->output(0)->type());
        prepack_op_value->replaceAllUsesWith(packed_weight);
        nodes_to_delete.insert(n);
      }
    }
  }
  // Two phases: drop all inputs first so no to-be-deleted node is still a
  // user of another, then destroy the nodes.
  for (auto n : nodes_to_delete) {
    n->removeAllInputs();
  }
  for (auto n : nodes_to_delete) {
    n->destroy();
  }
}
203
// Graph-level entry point for constant-folding prepack ops.
void FoldPrePackingOps(std::shared_ptr<Graph>& graph) {
  PrePackingOpsFolder(graph->block());
}
207
// Full pipeline: insert prepacked conv ops, fuse trailing eltwise ops into
// them, then constant-fold the prepack nodes. Order matters: fusion rewrites
// the attr constant before folding freezes the packed context.
void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before insertMkldnnPrePackedOps. Beginning of FuseConvWithEltwise\n",
      *graph);
  insertMkldnnPrePackedOps(graph);
  GRAPH_DEBUG(
      "After insertMkldnnPrePackedOps, before FuseReluWithPackedOps\n", *graph);
  FuseReluWithPackedOps(graph);
  GRAPH_DEBUG(
      "After FuseReluWithPackedOps, before FoldPrePackingOps\n", *graph);
  FoldPrePackingOps(graph);
  GRAPH_DEBUG("After FoldPrePackingOps. End of FuseConvWithEltwise\n", *graph);
}
221
222#else
223
// No-op stub used when PyTorch is built without MKL-DNN support.
void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG("MKLDNN Not enabled");
}
227
228#endif // AT_MKLDNN_ENABLED()
229
230} // namespace jit
231} // namespace torch
232