1 | #include <ATen/Config.h> |
2 | #include <ATen/code_template.h> |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | #include <torch/csrc/jit/jit_log.h> |
5 | #include <torch/csrc/jit/passes/constant_propagation.h> |
6 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
7 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
8 | #include <torch/csrc/jit/passes/mkldnn_rewrite.h> |
9 | #include <torch/csrc/jit/tensorexpr/kernel.h> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | |
14 | #if AT_MKLDNN_ENABLED() |
15 | |
16 | c10::VaryingShape<int64_t> getSizesOf(Node* n, size_t idx) { |
17 | auto tt = n->input(idx)->type()->cast<TensorType>(); |
18 | return tt->sizes(); |
19 | } |
20 | |
21 | void insertPrePackedConvOpForNode(Node* n) { |
22 | constexpr int POS_INPUT = 0; |
23 | constexpr int POS_WEIGHT = 1; |
24 | if (!tensorexpr::isContiguous( |
25 | n->input(POS_INPUT), at::MemoryFormat::ChannelsLast)) { |
26 | GRAPH_DEBUG( |
27 | "insertPrePackedConvOpForNode: input is not ChannelsLast contiguous" ); |
28 | return; |
29 | } |
30 | |
31 | if (!tensorexpr::isContiguous( |
32 | n->input(POS_WEIGHT), at::MemoryFormat::ChannelsLast)) { |
33 | GRAPH_DEBUG( |
34 | "insertPrePackedConvOpForNode: weight is not ChannelsLast contiguous" ); |
35 | return; |
36 | } |
37 | |
38 | // Leave depthwise conv2d to NNC |
39 | if (tensorexpr::conv2dIsSupportedJit(n)) { |
40 | GRAPH_DEBUG("insertPrePackedConvOpForNode: leave depthwise conv2d to NNC" ); |
41 | return; |
42 | } |
43 | |
44 | WithInsertPoint guard(n); |
45 | auto graph = n->owningGraph(); |
46 | |
47 | auto input_sizes = getSizesOf(n, POS_INPUT); |
48 | IValue input_size_value(*input_sizes.concrete_sizes()); |
49 | auto input_size = graph->insertConstant(input_size_value); |
50 | |
51 | auto prepack_node = graph->create( |
52 | Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack" ), 1); |
53 | |
54 | // skip input value |
55 | for (auto i = 1; i < n->inputs().size(); i++) { |
56 | Value* v = n->input(i); |
57 | prepack_node->addInput(v); |
58 | } |
59 | prepack_node->addInput(input_size); |
60 | auto attr = graph->insertConstant(IValue("none" )); |
61 | prepack_node->addInput(attr); |
62 | prepack_node->output()->setType( |
63 | getCustomClass("__torch__.torch.classes.mkldnn.ConvOpContext" )); |
64 | graph->insertNode(prepack_node); |
65 | |
66 | auto prepack_conv = graph->insertNode( |
67 | graph->create(Symbol::fromQualString("mkldnn_prepacked::conv2d_run" ), 1)); |
68 | prepack_conv->addInput(n->input(0)); |
69 | prepack_conv->addInput(prepack_node->output()); |
70 | prepack_conv->output()->setType(n->output()->type()->cast<TensorType>()); |
71 | |
72 | n->output()->replaceAllUsesWith(prepack_conv->output()); |
73 | } |
74 | |
75 | bool isTensorTypeCPU(Node* node) { |
76 | for (const auto& input : node->inputs()) { |
77 | auto type = input->type()->cast<TensorType>(); |
78 | if (!type) { |
79 | continue; |
80 | } |
81 | auto device = type->device(); |
82 | if (!device) { |
83 | return false; |
84 | } |
85 | if (!device->is_cpu()) { |
86 | return false; |
87 | } |
88 | } |
89 | return true; |
90 | } |
91 | |
92 | void insertPrePackedConvOp(Block* b) { |
93 | for (Node* n : b->nodes()) { |
94 | for (Block* b : n->blocks()) { |
95 | insertPrePackedConvOp(b); |
96 | } |
97 | |
98 | if (n->kind() == aten::conv2d) { |
99 | if (isTensorTypeCPU(n)) { |
100 | insertPrePackedConvOpForNode(n); |
101 | } |
102 | } |
103 | } |
104 | EliminateDeadCode(b); |
105 | } |
106 | |
// Graph-level entry point for the conv2d prepack rewrite.
void insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  insertPrePackedConvOp(graph->block());
}
110 | |
// Inserts all mkldnn prepacked ops into `graph`; currently only conv2d.
void insertMkldnnPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertMkldnnPrePackedConv2dOp(graph);
}
114 | |
115 | void insertMkldnnPrePackedOps(script::Module& module) { |
116 | for (auto& method : module.get_methods()) { |
117 | auto graph = method.graph(); |
118 | insertMkldnnPrePackedOps(graph); |
119 | } |
120 | for (script::Module m : module.children()) { |
121 | insertMkldnnPrePackedOps(m); |
122 | } |
123 | } |
124 | |
125 | void FuseReluWithPackedOps(std::shared_ptr<Graph>& graph) { |
126 | auto conv_op_rstring = at::jit::CodeTemplate(R"( |
127 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
128 | %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str): |
129 | %packed_weight_bias = mkldnn_prepacked::conv2d_prepack( |
130 | %weight, %bias, %stride, %padding, %dilation, %groups, |
131 | %input_size, %dummy_attr) |
132 | %conv2d_res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias) |
133 | %res = aten::${op}(%conv2d_res) |
134 | return (%res))" ); |
135 | |
136 | auto conv_op_fused_rstring = at::jit::CodeTemplate(R"( |
137 | graph(%input, %weight, %bias, %stride:int[], %padding:int[], |
138 | %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str): |
139 | %attr: str = prim::Constant[value="${op_attr}"]() |
140 | %packed_weight_bias : __torch__.torch.classes.mkldnn.ConvOpContext = mkldnn_prepacked::conv2d_prepack( |
141 | %weight, %bias, %stride, %padding, %dilation, %groups, |
142 | %input_size, %attr) |
143 | %res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias) |
144 | return (%res))" ); |
145 | |
146 | for (auto const& it : mkldnn::fusion_rewrite_map) { |
147 | std::string op = it.first; |
148 | if (op == std::string("none" )) { |
149 | continue; |
150 | } |
151 | |
152 | at::jit::TemplateEnv env; |
153 | env.s("op" , op); |
154 | |
155 | at::jit::TemplateEnv env_fused; |
156 | env_fused.s("op_attr" , op); |
157 | |
158 | SubgraphRewriter rewriter; |
159 | rewriter.RegisterRewritePattern( |
160 | conv_op_rstring.format(env), conv_op_fused_rstring.format(env_fused)); |
161 | |
162 | auto filters = it.second; |
163 | rewriter.runOnGraph(graph, filters); |
164 | } |
165 | } |
166 | |
// Constant-folds mkldnn_prepacked::conv2d_prepack nodes whose inputs are all
// constants: the prepack op is executed once at pass time and its packed
// weight object is embedded into the graph as a constant.
void PrePackingOpsFolder(Block* b) {
  auto is_foldable_op = [](const Node* n) -> bool {
    return (
        n->kind() ==
        Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"));
  };

  std::unordered_set<Node*> nodes_to_delete;
  for (Node* n : b->nodes()) {
    // Fold inside nested blocks first.
    for (Block* block : n->blocks()) {
      PrePackingOpsFolder(block);
    }
    if (is_foldable_op(n)) {
      // Runs the node eagerly; yields nullopt when any input is non-constant.
      auto optional_outputs = torch::jit::runNodeIfInputsAreConstant(n);
      if (optional_outputs) {
        auto outputs = optional_outputs.value();
        TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output");
        Value* prepack_op_value = n->output(0);
        auto graph = n->owningGraph();
        WithInsertPoint ins(prepack_op_value->node());
        // Insert a weak reference so the graph constant does not keep the
        // compilation unit alive via a reference cycle.
        auto weak_class_obj =
            outputs[0].toObject()->copy_to_weak_compilation_ref();
        Value* packed_weight = graph->insertConstant(weak_class_obj)
                                   ->setType(n->output(0)->type());
        prepack_op_value->replaceAllUsesWith(packed_weight);
        nodes_to_delete.insert(n);
      }
    }
  }
  // Two phases: first detach all inputs so no to-be-deleted node still
  // consumes another's output, then destroy them.
  for (auto n : nodes_to_delete) {
    n->removeAllInputs();
  }
  for (auto n : nodes_to_delete) {
    n->destroy();
  }
}
203 | |
// Graph-level entry point for folding constant prepack ops.
void FoldPrePackingOps(std::shared_ptr<Graph>& graph) {
  PrePackingOpsFolder(graph->block());
}
207 | |
// Full pipeline: insert prepacked conv ops, fuse trailing eltwise ops into
// them, then constant-fold the prepack calls. Each stage logs the graph.
void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before insertMkldnnPrePackedOps. Beginning of FuseConvWithEltwise\n",
      *graph);
  insertMkldnnPrePackedOps(graph);
  GRAPH_DEBUG(
      "After insertMkldnnPrePackedOps, before FuseReluWithPackedOps\n", *graph);
  FuseReluWithPackedOps(graph);
  GRAPH_DEBUG(
      "After FuseReluWithPackedOps, before FoldPrePackingOps\n", *graph);
  FoldPrePackingOps(graph);
  GRAPH_DEBUG("After FoldPrePackingOps. End of FuseConvWithEltwise\n", *graph);
}
221 | |
222 | #else |
223 | |
// No-op fallback when ATen is built without MKLDNN support.
void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG("MKLDNN Not enabled");
}
227 | |
228 | #endif // AT_MKLDNN_ENABLED() |
229 | |
230 | } // namespace jit |
231 | } // namespace torch |
232 | |