1 | #include <ATen/ATen.h> |
2 | #include <ATen/Config.h> |
3 | #include <ATen/Utils.h> |
4 | #include <ATen/core/symbol.h> |
5 | #include <ATen/native/layer_norm.h> |
6 | #include <c10/core/ScalarType.h> |
7 | #include <c10/util/Exception.h> |
8 | #include <c10/util/irange.h> |
9 | |
10 | #include <torch/csrc/jit/ir/alias_analysis.h> |
11 | #include <torch/csrc/jit/ir/constants.h> |
12 | #include <torch/csrc/jit/ir/ir.h> |
13 | #include <torch/csrc/jit/jit_log.h> |
14 | #include <torch/csrc/jit/passes/common_subexpression_elimination.h> |
15 | #include <torch/csrc/jit/passes/constant_propagation.h> |
16 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
17 | #include <torch/csrc/jit/passes/fold_conv_bn.h> |
18 | #include <torch/csrc/jit/passes/frozen_conv_folding.h> |
19 | #include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h> |
20 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
21 | #include <torch/csrc/jit/passes/peephole.h> |
22 | #include <torch/csrc/jit/passes/remove_mutation.h> |
23 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
24 | #include <torch/csrc/jit/runtime/custom_operator.h> |
25 | #include <torch/csrc/jit/runtime/operator_options.h> |
26 | #include <torch/csrc/jit/tensorexpr/types.h> |
27 | // clang-format off |
28 | // moving ConvUtils include induces import cycle |
29 | #include <ATen/native/ConvUtils.h> |
30 | #include <algorithm> |
31 | #include <memory> |
32 | #include <ATen/core/stack.h> |
33 | #include <c10/core/Layout.h> |
34 | #include <c10/util/StringUtil.h> |
35 | |
36 | #if AT_MKLDNN_ENABLED() |
37 | #include <ATen/CPUFunctions.h> |
38 | #include <dnnl_types.h> |
39 | #include <ATen/native/mkldnn/Utils.h> |
40 | #include <ATen/native/mkldnn/MKLDNNCommon.h> |
41 | #include <ideep.hpp> |
42 | #endif |
43 | |
44 | // clang-format on |
45 | |
46 | namespace torch { |
47 | namespace jit { |
48 | |
49 | #if AT_MKLDNN_ENABLED() |
50 | |
51 | using Tensor = at::Tensor; |
52 | |
53 | namespace { |
54 | |
55 | c10::AliasAnalysisKind aliasAnalysisFromSchema() { |
56 | return AliasAnalysisKind::FROM_SCHEMA; |
57 | } |
58 | |
59 | using ValueSet = std::unordered_set<Value*>; |
60 | using ValueSetPtr = std::shared_ptr<std::unordered_set<Value*>>; |
61 | |
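// Returns the node after which `v` is no longer used: the latest of `v`'s
// users, or `v`'s producing node if it has no uses. Used below to compute
// the liveness of aliasing sets.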
62 | Node* getLastUse(Value* v) { |
63 | auto last_use_node = v->node(); |
64 | for (const auto& use : v->uses()) { |
65 | if (use.user->isAfter(last_use_node)) { |
66 | last_use_node = use.user; |
67 | } |
68 | } |
69 | return last_use_node; |
70 | } |
71 | |
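// Union the aliasing sets of `existing` and `new_v`: every member of
// `new_v`'s set is moved into `existing`'s set and remapped to it.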
72 | void merge_sets( |
73 | std::unordered_map<Value*, ValueSetPtr>& alias_mapping, |
74 | Value* existing, |
75 | Value* new_v) { |
76 | if (alias_mapping[existing] == alias_mapping[new_v]) { |
77 | return; |
78 | } |
79 | auto existing_set = alias_mapping[existing]; |
80 | auto set_to_remove = alias_mapping[new_v]; |
81 | for (auto it = set_to_remove->begin(); it != set_to_remove->end(); it++) { |
82 | existing_set->insert(*it); |
83 | alias_mapping[*it] = existing_set; |
84 | } |
85 | } |
86 | |
// assert there are no uses of Tensors inside container types
88 | void assertNonTensorTypeDoesNotContainTensors(TypePtr type) { |
89 | if (type->cast<TensorType>()) { |
90 | return; |
91 | } |
92 | for (const auto& t : type->containedTypes()) { |
93 | TORCH_INTERNAL_ASSERT(!t->cast<TensorType>()); |
94 | } |
95 | } |
96 | |
97 | void InplaceMKLDNNSubgraph(std::shared_ptr<Graph> graph) { |
// This function first calculates the aliasing sets,
// then calculates the last node for which each aliasing set is alive.
// Then we go through each node; if a node has an equivalent inplace
// variant and the aliasing set for its input is dead after this node, we
// inplace it. Then we merge the aliasing sets for the input and output of the
// node and extend the liveness of the set. To inplace a node we need to
// prove that the device and dtype of the input and output are the same, which
// we've already done, and that the output size is the same as the input size,
// which is guaranteed by the explicit Broadcast nodes (inserted for other
// reasons).
// The graphs here are simple subgraphs without uses of Tensors in
// containers (Lists, GetAttrs, etc).
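//
// For example, if the subgraph contains
//   %y = aten::relu(%x)
// and the aliasing set of %x is not used after this node, the node is
// rewritten to the inplace variant aten::relu_ and the aliasing sets of
// %x and %y are merged.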
110 | |
111 | // CALCULATE ALIASING SETS |
112 | |
113 | auto aliasDb = torch::make_unique<AliasDb>(graph); |
114 | |
115 | // map from Value to its Aliasing Set |
116 | std::unordered_map<Value*, ValueSetPtr> alias_mapping; |
117 | ValueSet set; |
118 | ValueSetPtr input_set = std::make_shared<ValueSet>(set); |
119 | for (Value* v : graph->inputs()) { |
120 | if (v->type()->cast<TensorType>()) { |
121 | input_set->insert(v); |
122 | alias_mapping[v] = input_set; |
123 | } else { |
124 | assertNonTensorTypeDoesNotContainTensors(v->type()); |
125 | } |
126 | } |
127 | |
128 | for (Node* n : graph->nodes()) { |
129 | for (Value* output : n->outputs()) { |
130 | if (!output->type()->cast<TensorType>()) { |
131 | assertNonTensorTypeDoesNotContainTensors(output->type()); |
132 | continue; |
133 | } |
134 | |
135 | std::unordered_set<Value*> new_set = {output}; |
136 | alias_mapping[output] = std::make_shared<ValueSet>(new_set); |
137 | for (Value* input : n->inputs()) { |
138 | if (aliasDb->mayAlias(input, output)) { |
139 | merge_sets(alias_mapping, input, output); |
140 | } |
141 | } |
142 | } |
143 | } |
144 | |
145 | // CALCULATE ALIASING SET LIVENESS |
146 | |
147 | // map from aliased set -> last use of set |
148 | std::unordered_map<ValueSetPtr, Node*> set_liveness; |
149 | for (auto& set : alias_mapping) { |
150 | if (set_liveness.count(set.second)) { |
151 | continue; |
152 | } |
153 | Node* last = nullptr; |
154 | for (const auto& v : *set.second) { |
155 | auto k = v->node()->kind(); |
156 | if (k == prim::Constant || k == prim::ConstantMKLDNNTensor || |
157 | k == prim::Param) { |
158 | last = graph->return_node(); |
159 | continue; |
160 | } |
161 | |
162 | auto last_use = getLastUse(v); |
163 | if (!last || last_use->isAfter(last)) { |
164 | last = last_use; |
165 | } |
166 | } |
167 | set_liveness[set.second] = last; |
168 | } |
169 | |
170 | // REUSING MEMORY BY REINPLACING NODES |
171 | std::vector<Node*> nodes_to_inplace; |
172 | |
173 | auto add_to_inplace_set = [&](Node* node) { |
174 | // defer making the inplacing change because that would invalidate the old |
175 | // Node output Value* |
176 | nodes_to_inplace.push_back(node); |
177 | TORCH_INTERNAL_ASSERT(node->outputs().size() == 1); |
178 | auto output_liveness_end = |
179 | set_liveness[alias_mapping[node->outputs().at(0)]]; |
180 | merge_sets(alias_mapping, node->inputs().at(0), node->output()); |
181 | set_liveness[alias_mapping[node->output()]] = output_liveness_end; |
182 | }; |
183 | |
184 | for (Node* node : graph->nodes()) { |
185 | auto k = node->kind(); |
186 | if (k == aten::relu || k == aten::sigmoid || k == aten::dropout || |
187 | k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid || |
188 | k == prim::MKLDNNHardTanh || k == aten::tanh || |
189 | k == prim::MKLDNNClamp || k == Symbol::prim("MKLDNNScalarMul" ) || |
190 | k == Symbol::prim("MKLDNNLayerNorm" )) { |
191 | if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) { |
192 | continue; |
193 | } |
194 | add_to_inplace_set(node); |
195 | } else if (k == aten::mul || k == aten::add) { |
196 | // the binary operators (add/mul) are commutative and only take tensor |
197 | // inputs, so we can inplace either the first or second input |
198 | int64_t reusable_value_index = -1; |
199 | for (const auto i : c10::irange(2)) { |
200 | TORCH_INTERNAL_ASSERT(node->inputs().at(i)->type()->cast<TensorType>()); |
201 | if (!set_liveness[alias_mapping[node->inputs().at(i)]]->isAfter(node)) { |
202 | reusable_value_index = i; |
203 | break; |
204 | } |
205 | } |
206 | |
207 | if (reusable_value_index == -1) { |
208 | continue; |
209 | } |
210 | |
211 | if (reusable_value_index == 1) { |
212 | node->insertInput(0, node->inputs().at(1)); |
213 | node->removeInput(2); |
214 | } |
215 | add_to_inplace_set(node); |
216 | } |
217 | } |
218 | |
219 | for (Node* node : nodes_to_inplace) { |
220 | node->replaceWithNewSymbol( |
221 | Symbol::fromQualString(node->schema().name() + "_" )); |
222 | node->destroy(); |
223 | } |
224 | } |
225 | |
// This is a factory function that creates an Operation that takes
// MKLDNN tensors and unpacks them into 1D contiguous tensors that we can
// run aten operations on. The precondition for using this function is that
// the aten operation in `aten_op` must map zero to zero; in other words,
// `aten_op(0) == 0`. The reason for this precondition has to do with the
// blocked formats MKLDNN uses to lay out tensor elements (nChw8c, nChw16c,
// etc). It splits the channel dimension into chunks of 8/16 and makes it the
// innermost dimension. Whenever the channel dim isn't divisible by 8/16, the
// innermost dimension is padded with 0s. The precondition `aten_op(0) == 0`
// allows us to avoid any special casing of padded elements.
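//
// For example, with 3 channels stored as nChw8c, each channel block holds 8
// values and the last 5 slots are zero padding; since relu(0) == 0 we can
// apply the op over the whole (padded) buffer, whereas an op with f(0) != 0
// would corrupt the padding.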
236 | Operation createUnaryOp( |
237 | std::function<void(at::Tensor output, at::Tensor input)> aten_op, |
238 | bool inplace = false) { |
239 | return [aten_op, inplace](Stack& stack) { |
240 | auto a = pop(stack).toTensor(); |
241 | c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); |
// we cast `a` to an `ideep::tensor` so we can get at its descriptor,
// which we then use to set up the `out` tensor with the same properties as `a`
244 | auto a_it = at::native::itensor_from_mkldnn(a); |
245 | auto mkldnn_raw_data = a_it.get_data_handle(); |
246 | auto a_options_with_strided = a.options().layout(c10::kStrided); |
247 | |
// the tensor's physical size could be bigger than its logical size
// `a_it.get_desc().get_size()` returns the real physical size (in bytes),
// which we use to compute `nelem` for the `aten` ops
251 | auto nelem = static_cast<int64_t>( |
252 | a_it.get_desc().get_size() / elementSize(a.scalar_type())); |
253 | // we also wrap `a` storage into an aten tensor |
254 | auto in_aten = |
255 | at::from_blob(mkldnn_raw_data, {nelem}, a_options_with_strided); |
256 | |
257 | auto out_raw_data = mkldnn_raw_data; |
258 | auto out = a; |
259 | if (!inplace) { |
260 | // `a_it.get_desc()` will allocate a tensor |
261 | // of the right physical size. |
262 | auto it_empty = ideep::tensor(a_it.get_desc()); |
263 | TORCH_INTERNAL_ASSERT(it_empty.get_desc() == a_it.get_desc()); |
264 | out = at::native::new_with_itensor_mkldnn( |
265 | std::move(it_empty), |
266 | optTypeMetaToScalarType(a.options().dtype_opt()), |
267 | a.options().device_opt()); |
268 | |
269 | out_raw_data = at::native::itensor_from_mkldnn(out).get_data_handle(); |
270 | } |
271 | |
272 | TORCH_INTERNAL_ASSERT( |
273 | a_it.get_desc().get_size() % elementSize(a.scalar_type()) == 0); |
274 | |
275 | auto out_aten = at::from_blob( |
276 | out_raw_data, {static_cast<int64_t>(nelem)}, a_options_with_strided); |
277 | aten_op(out_aten, in_aten); |
278 | push(stack, out); |
279 | }; |
280 | } |
281 | |
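// Pops the prim::MKLDNNLayerNorm(_) arguments off the stack (input,
// normalized_shape, weight, bias, eps, cudnn_enable) and dispatches to the
// MKLDNN fast path for fp32 layer_norm with last-dimension weight and bias.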
282 | void MKLDNNLayerNormOp(Stack& stack, bool inplace) { |
283 | c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); |
284 | |
285 | // enable_cudnn not used |
286 | pop(stack); |
287 | auto eps = pop(stack).toDouble(); |
288 | |
289 | Tensor bias{}; |
290 | Tensor weight{}; |
291 | auto bias_ival = pop(stack); |
292 | TORCH_INTERNAL_ASSERT(bias_ival.isTensor()); |
293 | bias = bias_ival.toTensor(); |
294 | |
295 | auto weight_ival = pop(stack); |
296 | TORCH_INTERNAL_ASSERT(weight_ival.isTensor()); |
297 | weight = weight_ival.toTensor(); |
298 | |
299 | auto shape = pop(stack).toDimVector(); |
300 | auto input = pop(stack).toTensor(); |
301 | |
302 | at::Tensor dst, mean, rstd; |
303 | std::tie(dst, mean, rstd) = |
304 | at::native::mkldnn_layer_norm_last_index_weight_bias_f32( |
305 | input, shape, weight, bias, eps, inplace); |
306 | push(stack, dst); |
}
308 | |
309 | Operation BroadOp(const Node* node) { |
310 | return [](Stack& stack) { |
311 | auto b = pop(stack).toTensor(); |
312 | auto a = pop(stack).toTensor(); |
313 | auto b_size = b.sizes(); |
314 | auto a_size = a.sizes(); |
315 | if (a_size.equals(b_size)) { |
// TODO: follow up with MKLDNN about the best way
// to handle perf-incompatible formats
318 | push(stack, a, b); |
319 | return; |
320 | } else { |
321 | auto out_size = at::infer_size(a_size, b_size); |
322 | int64_t out_numel = out_size[0]; |
323 | for (size_t i = 1, end = out_size.size(); i < end; ++i) { |
324 | out_numel = out_numel * out_size[i]; |
325 | } |
326 | |
327 | auto exp_a = a; |
328 | auto exp_b = b; |
329 | int stacked = 0; |
330 | // mkldnn tensors only support reshape, not expand or view operators |
331 | if (a_size.equals(out_size)) { |
332 | push(stack, a); |
333 | ++stacked; |
334 | } else if (out_numel == a.numel()) { |
335 | exp_a = a.reshape(out_size); |
336 | } else { |
// TODO: consider initializing to a blocked layout
// directly if needed
339 | exp_a = a.to_dense().expand(out_size).to_mkldnn(); |
340 | } |
341 | |
342 | if (b_size.equals(out_size)) { |
343 | push(stack, b); |
344 | ++stacked; |
345 | } else if (out_numel == b.numel()) { |
346 | exp_b = b.reshape(out_size); |
347 | } else { |
348 | exp_b = b.to_dense().expand(out_size).to_mkldnn(); |
349 | } |
350 | |
351 | if (stacked < 2) { |
352 | if (stacked == 1) { |
353 | pop(stack); |
354 | } |
355 | // If one of the inputs was expanded and converted to nchw/nhwc |
356 | // we might end up in a very bad spot if the second argument |
357 | // is in a blocked format. In this case, MKLDNN uses its |
358 | // reference implementation for a binary operation that follows |
359 | // these broadcasts and it could be up to ~100x slower. |
360 | // We use a very simple heuristic to convert an arg in nchw |
361 | // to the blocked format of the other argument. |
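// For example, if exp_a was created via to_dense().expand().to_mkldnn() it
// is in plain nchw, while exp_b may be in a blocked format such as nChw16c;
// in that case we reorder exp_a into exp_b's descriptor below.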
362 | c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); |
363 | auto a_it = at::native::itensor_from_mkldnn(exp_a); |
364 | auto b_it = at::native::itensor_from_mkldnn(exp_b); |
365 | |
// `is_public_format` means the tensor's physical layout is not an MKLDNN
// blocked layout, i.e. it is nchw or nhwc but not nChw8c
368 | if (!a_it.is_public_format()) { |
369 | if (b_it.is_public_format()) { |
370 | b_it = b_it.reorder_if_differ_in(a_it.get_desc()); |
371 | } |
372 | } else if (!b_it.is_public_format()) { |
373 | if (a_it.is_public_format()) { |
374 | a_it = a_it.reorder_if_differ_in(b_it.get_desc()); |
375 | } |
376 | } |
377 | |
378 | auto a_options = exp_a.options(); |
379 | auto a_out = at::native::new_with_itensor_mkldnn( |
380 | std::move(a_it), |
381 | optTypeMetaToScalarType(a_options.dtype_opt()), |
382 | a_options.device_opt()); |
383 | push(stack, a_out); |
384 | auto b_options = exp_b.options(); |
385 | auto b_out = at::native::new_with_itensor_mkldnn( |
386 | std::move(b_it), |
387 | optTypeMetaToScalarType(b_options.dtype_opt()), |
388 | b_options.device_opt()); |
389 | push(stack, b_out); |
}
391 | } |
392 | }; |
393 | } |
394 | |
395 | static std::function<void(at::Tensor output, at::Tensor input)> hardtanh_helper( |
396 | const Node* n) { |
397 | auto min_val = n->f(attr::min_val); |
398 | auto max_val = n->f(attr::max_val); |
399 | return [min_val, max_val](at::Tensor output, at::Tensor input) { |
400 | at::cpu::hardtanh_out(output, input, min_val, max_val); |
401 | }; |
402 | } |
403 | |
404 | static std::function<void(at::Tensor output, at::Tensor input)> clamp_helper( |
405 | const Node* n) { |
406 | auto min_val = n->f(attr::min_val); |
407 | auto max_val = n->f(attr::max_val); |
408 | return [min_val, max_val](at::Tensor output, at::Tensor input) { |
409 | at::cpu::clamp_out(output, input, min_val, max_val); |
410 | }; |
411 | } |
412 | |
413 | // any op added to this registry needs to meet |
414 | // the precondition: `aten_op(0) == 0` |
415 | const RegisterOperators MKLDNNHardSwishOpReg({ |
416 | torch::jit::Operator( |
417 | "prim::MKLDNNHardSwish_(Tensor(a!) self) -> Tensor(a!)" , |
418 | createUnaryOp( |
419 | [](at::Tensor output, at::Tensor input) { |
420 | at::cpu::hardswish_out(output, input); |
421 | }, |
422 | true), |
423 | AliasAnalysisKind::FROM_SCHEMA), |
424 | torch::jit::Operator( |
425 | "prim::MKLDNNHardSigmoid_(Tensor(a!) self) -> Tensor(a!)" , |
426 | createUnaryOp( |
427 | [](at::Tensor output, at::Tensor input) { |
428 | at::cpu::hardsigmoid_out(output, input); |
429 | }, |
430 | true), |
431 | AliasAnalysisKind::FROM_SCHEMA), |
432 | torch::jit::Operator( |
433 | "prim::MKLDNNHardTanh_(Tensor(a!) self) -> Tensor(a!)" , |
434 | [](const Node* n) -> Operation { |
435 | return createUnaryOp(hardtanh_helper(n), true); |
436 | }, |
437 | AliasAnalysisKind::FROM_SCHEMA), |
438 | torch::jit::Operator( |
439 | "prim::MKLDNNClamp_(Tensor(a!) self) -> Tensor(a!)" , |
440 | [](const Node* n) -> Operation { |
441 | return createUnaryOp(clamp_helper(n), true); |
442 | }, |
443 | AliasAnalysisKind::FROM_SCHEMA), |
444 | torch::jit::Operator( |
445 | "prim::MKLDNNHardSwish(Tensor a) -> Tensor" , |
446 | createUnaryOp( |
447 | [](at::Tensor output, at::Tensor input) { |
448 | at::cpu::hardswish_out(output, input); |
449 | }, |
450 | false), |
451 | AliasAnalysisKind::FROM_SCHEMA), |
452 | torch::jit::Operator( |
453 | "prim::MKLDNNHardSigmoid(Tensor a) -> Tensor" , |
454 | createUnaryOp( |
455 | [](at::Tensor output, at::Tensor input) { |
456 | at::cpu::hardsigmoid_out(output, input); |
457 | }, |
458 | false), |
459 | AliasAnalysisKind::FROM_SCHEMA), |
460 | torch::jit::Operator( |
461 | "prim::MKLDNNHardTanh(Tensor self) -> Tensor" , |
462 | [](const Node* n) -> Operation { |
463 | return createUnaryOp(hardtanh_helper(n), false); |
464 | }, |
465 | AliasAnalysisKind::FROM_SCHEMA), |
466 | torch::jit::Operator( |
467 | "prim::MKLDNNClamp(Tensor self) -> Tensor" , |
468 | [](const Node* n) -> Operation { |
469 | return createUnaryOp(clamp_helper(n), false); |
470 | }, |
471 | AliasAnalysisKind::FROM_SCHEMA), |
472 | }); |
473 | |
474 | const RegisterOperators BroadOpReg({ |
475 | torch::jit::Operator( |
476 | prim::BroadcastMKLDNNTensors, |
477 | BroadOp, |
478 | AliasAnalysisKind::INTERNAL_SPECIAL_CASE), |
479 | }); |
480 | |
481 | const RegisterOperators MKLDNNLayerNormOpReg({ |
482 | torch::jit::Operator( |
483 | "prim::MKLDNNLayerNorm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor" , |
484 | [](Stack& stack) { MKLDNNLayerNormOp(stack, false); }, |
485 | AliasAnalysisKind::FROM_SCHEMA), |
486 | torch::jit::Operator( |
487 | "prim::MKLDNNLayerNorm_(Tensor(a!) input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor(a!)" , |
488 | [](Stack& stack) { MKLDNNLayerNormOp(stack, true); }, |
489 | AliasAnalysisKind::FROM_SCHEMA), |
490 | }); |
491 | |
492 | Operation ConstantMKLDNNTensorOp(const Node* node) { |
493 | const auto& t = node->t(attr::value); |
494 | return [t](Stack& stack) { |
495 | push(stack, t); |
496 | return 0; |
497 | }; |
498 | } |
499 | |
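// Scalar multiply for MKLDNN tensors: computes `out = scalar * tensor` using
// ideep's eltwise_linear algorithm (alpha = scalar; beta is left at its
// default of 0).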
500 | Tensor mkldnn_tensor_scalar_mul(Tensor& tensor, Tensor& out, float scalar) { |
501 | ideep::tensor& x = at::native::itensor_from_mkldnn(tensor); |
502 | ideep::tensor& z = at::native::itensor_from_mkldnn(out); |
503 | ideep::eltwise_forward::compute( |
504 | x, |
505 | z, |
506 | ideep::algorithm::eltwise_linear, |
507 | ideep::prop_kind::forward_inference, |
508 | /*alpha*/ scalar); |
509 | return out; |
510 | } |
511 | |
// aten::convolution does a lot of precomputation and dispatching before
// mkldnn_convolution is called. By registering the op here we can invoke it
// directly and avoid that overhead. Avoiding dispatch overhead for other
// operators - relu, add, etc. - did not benchmark as speeding up models
// noticeably; only the additional overhead of `convolution` warrants the
// custom operator.
517 | jit::RegisterOperators reg_fut_ops({ |
518 | jit::Operator( |
519 | // XXX: this follows the schema convention of conv2d/conv3d, not |
520 | // aten::mkldnn_convolution, which is different for some reason! |
521 | "prim::mkldnn_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor" , |
522 | [](jit::Stack& stack) { |
523 | int64_t groups = pop(stack).toInt(); |
524 | auto dilation = pop(stack).toIntVector(); |
525 | auto padding = pop(stack).toIntVector(); |
526 | auto stride = pop(stack).toIntVector(); |
527 | |
528 | Tensor bias; |
529 | IValue bias_ival = pop(stack); |
530 | if (!bias_ival.isNone()) { |
531 | bias = bias_ival.toTensor(); |
532 | } |
533 | Tensor weight = pop(stack).toTensor(); |
534 | Tensor input = pop(stack).toTensor(); |
535 | |
536 | at::AutoDispatchBelowAutograd mode; |
// aten::convolution takes care of the 0-dim case before calling into
// the backends
539 | if (input.size(0) == 0) { |
540 | std::vector<int64_t> o = at::native::conv_output_size( |
541 | input.sizes(), weight.sizes(), padding, stride, dilation); |
542 | push( |
543 | stack, |
544 | at::native::empty_mkldnn( |
545 | o, |
546 | optTypeMetaToScalarType(input.options().dtype_opt()), |
547 | input.options().layout_opt(), |
548 | input.options().device_opt(), |
549 | input.options().pinned_memory_opt())); |
550 | return; |
551 | } |
552 | // aten::convolution also checks dtype mismatches |
553 | TORCH_CHECK( |
554 | input.options().type_equal(weight.options()), |
555 | "Input type (" , |
556 | input.toString(), |
557 | ") and weight type (" , |
558 | weight.toString(), |
559 | ") should be the same" ); |
560 | |
561 | push( |
562 | stack, |
563 | at::native::mkldnn_convolution( |
564 | input, weight, bias, padding, stride, dilation, groups)); |
565 | }, |
566 | aliasAnalysisFromSchema()), |
// registering these as custom operators avoids a Scalar->Tensor->Scalar
// conversion in the default bindings
569 | jit::Operator( |
570 | "prim::MKLDNNScalarMul(Tensor self, Scalar other) -> Tensor" , |
571 | [](jit::Stack& stack) { |
572 | c10::impl::ExcludeDispatchKeyGuard edkg( |
573 | c10::autograd_dispatch_keyset); |
574 | float other = pop(stack).toScalar().toFloat(); |
575 | Tensor self = pop(stack).toTensor(); |
576 | auto out = at::native::empty_mkldnn( |
577 | self.sizes(), |
578 | optTypeMetaToScalarType(self.options().dtype_opt()), |
579 | self.options().layout_opt(), |
580 | self.options().device_opt(), |
581 | self.options().pinned_memory_opt()); |
582 | |
583 | mkldnn_tensor_scalar_mul(self, out, other); |
584 | push(stack, out); |
585 | }, |
586 | aliasAnalysisFromSchema()), |
587 | jit::Operator( |
588 | "prim::MKLDNNScalarMul_(Tensor(a!) self, Scalar other) -> Tensor(a!)" , |
589 | [](jit::Stack& stack) { |
590 | c10::impl::ExcludeDispatchKeyGuard edkg( |
591 | c10::autograd_dispatch_keyset); |
592 | float other = pop(stack).toScalar().toFloat(); |
593 | Tensor self = pop(stack).toTensor(); |
594 | mkldnn_tensor_scalar_mul(self, self, other); |
595 | push(stack, self); |
596 | }, |
597 | aliasAnalysisFromSchema()), |
598 | }); |
599 | |
// This is registered as its own op instead of as prim::Constant because it
// does not serialize, which is an invariant of prim::Constant
// TODO: make mkldnn tensors serialize...
603 | const RegisterOperators MKLDNNConstantOp({ |
604 | torch::jit::Operator( |
605 | prim::ConstantMKLDNNTensor, |
606 | ConstantMKLDNNTensorOp, |
607 | AliasAnalysisKind::INTERNAL_SPECIAL_CASE), |
608 | }); |
609 | |
610 | Node* createConstantMKLDNNTensorOp(Graph* g, const Tensor& mkldnn_tensor) { |
611 | TORCH_INTERNAL_ASSERT(mkldnn_tensor.is_mkldnn()); |
612 | auto op = g->create(prim::ConstantMKLDNNTensor); |
613 | op->t_(attr::value, mkldnn_tensor); |
614 | return op; |
615 | } |
616 | |
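// We only move weights to MKLDNN for float32 CPU tensors with at least one
// dimension.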
617 | bool supportedMKLDNNWeight(const Tensor& weight) { |
618 | return weight.device().is_cpu() && weight.dtype() == c10::ScalarType::Float && |
619 | weight.ndimension() != 0; |
620 | } |
621 | |
622 | void replaceInputWithMKLDNNTensor(Node* n, size_t index) { |
623 | Value* input = n->inputs().at(index); |
624 | auto mkldnn_tensor = constant_as<Tensor>(input)->to_mkldnn(); |
625 | auto mkldnn_tensor_value = |
626 | createConstantMKLDNNTensorOp(n->owningGraph(), mkldnn_tensor) |
627 | ->insertBefore(n) |
628 | ->output(); |
629 | mkldnn_tensor_value->setDebugName(input->debugName() + "_mkldnn" ); |
630 | n->replaceInputWith(input, mkldnn_tensor_value); |
631 | } |
632 | |
633 | void replaceInputWithMKLDNNTensor( |
634 | Node* n, |
635 | const std::string& name, |
636 | const at::Tensor& mkldnn_tensor) { |
637 | Value* input = n->namedInput(name); |
638 | auto mkldnn_tensor_value = |
639 | createConstantMKLDNNTensorOp(n->owningGraph(), mkldnn_tensor) |
640 | ->insertBefore(n) |
641 | ->output(); |
642 | mkldnn_tensor_value->setDebugName(input->debugName() + "_mkldnn" ); |
643 | n->replaceInputWith(input, mkldnn_tensor_value); |
644 | } |
645 | |
646 | void replaceInputWithMKLDNNTensor(Node* n, const std::string& name) { |
647 | Value* input = n->namedInput(name); |
648 | auto mkldnn_tensor = constant_as<Tensor>(input)->to_mkldnn(); |
649 | replaceInputWithMKLDNNTensor(n, name, mkldnn_tensor); |
650 | } |
651 | |
652 | void moveConvWeightsToMKLDNN(Node* conv) { |
653 | auto conv_w_mkldnn = |
654 | constant_as<Tensor>(conv->namedInput("weight" )).value().to_mkldnn(); |
655 | std::vector<int64_t> padding = |
656 | toIValue(conv->namedInput("padding" ))->toIntVector(); |
657 | std::vector<int64_t> stride = |
658 | toIValue(conv->namedInput("stride" ))->toIntVector(); |
659 | std::vector<int64_t> dilation = |
660 | toIValue(conv->namedInput("dilation" ))->toIntVector(); |
661 | auto groups = constant_as<int64_t>(conv->namedInput("groups" )).value(); |
662 | |
663 | if (conv->kind() == aten::conv2d) { |
664 | conv_w_mkldnn = mkldnn_reorder_conv2d_weight( |
665 | conv_w_mkldnn, padding, stride, dilation, groups); |
666 | } else if (conv->kind() == aten::conv3d) { |
667 | conv_w_mkldnn = mkldnn_reorder_conv3d_weight( |
668 | conv_w_mkldnn, padding, stride, dilation, groups); |
669 | } else { |
670 | TORCH_INTERNAL_ASSERT(false); |
671 | } |
672 | replaceInputWithMKLDNNTensor(conv, "weight" , conv_w_mkldnn); |
673 | |
674 | if (conv->namedInput("bias" )->type() != NoneType::get()) { |
675 | replaceInputWithMKLDNNTensor(conv, "bias" ); |
676 | } |
677 | } |
678 | |
679 | void moveWeightsToMKLDNN(Node* n) { |
// conv goes through a special pathway so we can call the mkldnn conv weight
// reorder primitives
682 | if (n->kind() == aten::conv2d || n->kind() == aten::conv3d) { |
683 | moveConvWeightsToMKLDNN(n); |
684 | } else { |
685 | for (size_t i = 0; i < n->inputs().size(); ++i) { |
686 | if (!n->input(i)->type()->cast<TensorType>() || |
687 | n->input(i)->node()->kind() != prim::Constant) { |
688 | continue; |
689 | } |
690 | replaceInputWithMKLDNNTensor(n, i); |
691 | } |
692 | } |
693 | } |
694 | |
695 | static void clamp_node_creator( |
696 | Node* body_node, |
697 | c10::Symbol kind, |
698 | double min_val, |
699 | double max_val) { |
700 | WithInsertPoint insert_guard{body_node}; |
701 | auto out_node = |
702 | body_node->owningGraph()->create({kind}, {body_node->input(0)}, 1); |
703 | // N.B. we can't use `insert` as it calls `getOperation` (via |
704 | // `emitBuiltinCall`) which uses `min_val` and `max_val` attrs which we |
705 | // haven't set yet. |
706 | body_node->owningGraph()->insertNode(out_node); |
707 | auto out_val = out_node->output(); |
708 | out_node->f_(attr::min_val, min_val); |
709 | out_node->f_(attr::max_val, max_val); |
710 | out_val->copyMetadata(body_node->output()); |
711 | body_node->output()->replaceAllUsesWith(out_val); |
712 | body_node->destroy(); |
713 | } |
714 | |
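// Rewrites a prim::MKLDNNGroup node so that its body runs on MKLDNN tensors:
// Tensor inputs are converted with aten::to_mkldnn, Tensor outputs are
// converted back with aten::to_dense, constant weights are moved to MKLDNN,
// and supported body nodes are replaced with their MKLDNN equivalents.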
715 | void ComputeSubgraphInMKLDNN(Node* subgraph_node) { |
716 | auto graph = subgraph_node->owningGraph(); |
717 | Value* none_value = nullptr; |
718 | { |
719 | WithInsertPoint guard(subgraph_node); |
720 | none_value = graph->insertConstant(IValue()); |
721 | } |
722 | for (size_t i = 0; i < subgraph_node->inputs().size(); ++i) { |
723 | Value* v = subgraph_node->inputs().at(i); |
724 | if (!v->type()->cast<TensorType>()) { |
725 | continue; |
726 | } |
727 | auto to_mkldnn = |
728 | graph->create(c10::Symbol::fromQualString("aten::to_mkldnn" ), 1) |
729 | ->insertBefore(subgraph_node); |
730 | to_mkldnn->addInput(v); |
731 | to_mkldnn->addInput(none_value); |
732 | subgraph_node->replaceInput(i, to_mkldnn->output()); |
733 | } |
734 | |
735 | for (size_t i = 0; i < subgraph_node->outputs().size(); ++i) { |
736 | Value* v = subgraph_node->outputs().at(i); |
737 | if (!v->type()->cast<TensorType>()) { |
738 | continue; |
739 | } |
740 | auto from_mkldnn = |
741 | graph |
742 | ->create( |
743 | c10::Symbol::fromQualString("aten::to_dense" ), {v, none_value}) |
744 | ->insertAfter(subgraph_node); |
745 | v->replaceAllUsesAfterNodeWith(from_mkldnn, from_mkldnn->output()); |
746 | } |
747 | |
748 | auto subgraph = SubgraphUtils::getSubgraph(subgraph_node); |
749 | for (auto it = subgraph->block()->nodes().begin(); |
750 | it != subgraph->block()->nodes().end();) { |
751 | Node* body_node = *it; |
752 | it++; |
753 | |
754 | moveWeightsToMKLDNN(body_node); |
755 | |
756 | if (body_node->kind() == aten::add || |
757 | (body_node->kind() == aten::mul && |
758 | body_node->input(1)->type()->cast<TensorType>())) { |
759 | auto node = body_node->owningGraph()->create( |
760 | Symbol::prim("BroadcastMKLDNNTensors" ), |
761 | {body_node->inputs().at(0), body_node->inputs().at(1)}, |
762 | 2); |
763 | node->insertBefore(body_node); |
764 | body_node->replaceInput(0, node->outputs().at(0)); |
765 | body_node->replaceInput(1, node->outputs().at(1)); |
766 | } |
767 | if (body_node->kind() == aten::mul && |
768 | body_node->input(1)->type()->isSubtypeOf(*NumberType::get())) { |
769 | body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNScalarMul" )); |
770 | body_node->destroy(); |
771 | continue; |
772 | } |
773 | |
774 | if (body_node->matches( |
775 | "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor" )) { |
776 | body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNLayerNorm" )); |
777 | body_node->destroy(); |
778 | continue; |
779 | } |
780 | |
781 | if (body_node->kind() == aten::hardswish) { |
782 | body_node->replaceWithNewSymbol(prim::MKLDNNHardSwish); |
783 | body_node->destroy(); |
784 | continue; |
785 | } |
786 | |
787 | if (body_node->kind() == aten::hardsigmoid) { |
788 | body_node->replaceWithNewSymbol(prim::MKLDNNHardSigmoid); |
789 | body_node->destroy(); |
790 | continue; |
791 | } |
792 | |
793 | if (body_node->kind() == aten::relu6) { |
794 | clamp_node_creator(body_node, prim::MKLDNNHardTanh, 0., 6.); |
795 | continue; |
796 | } |
797 | |
798 | if (body_node->kind() == aten::hardtanh) { |
799 | auto min_val = |
800 | constant_as<double>(body_node->namedInput("min_val" )).value(); |
801 | auto max_val = |
802 | constant_as<double>(body_node->namedInput("max_val" )).value(); |
803 | clamp_node_creator(body_node, prim::MKLDNNHardTanh, min_val, max_val); |
804 | continue; |
805 | } |
806 | |
807 | if (body_node->kind() == aten::clamp) { |
808 | auto min_val = constant_as<double>(body_node->namedInput("min" )).value(); |
809 | auto max_val = constant_as<double>(body_node->namedInput("max" )).value(); |
810 | clamp_node_creator(body_node, prim::MKLDNNClamp, min_val, max_val); |
811 | continue; |
812 | } |
813 | |
814 | if (body_node->kind() == aten::conv2d || |
815 | body_node->kind() == aten::conv3d) { |
// this node doesn't handle string padding yet...
817 | if (!body_node->namedInput("padding" )->type()->cast<StringType>()) { |
818 | body_node->replaceWithNewSymbol(Symbol::prim("mkldnn_convolution" )); |
819 | body_node->destroy(); |
820 | continue; |
821 | } |
822 | } |
823 | } |
824 | } |
825 | |
826 | bool nonConstantParameters(Node* n) { |
827 | for (size_t i = 1; i < n->inputs().size(); i++) { |
828 | if (n->inputs().at(i)->node()->kind() != prim::Constant) { |
829 | return true; |
830 | } |
831 | } |
832 | return false; |
833 | } |
834 | |
835 | bool frozenMkldnnCompatibleLinearNode(Node* n) { |
836 | if (nonConstantParameters(n)) { |
837 | return false; |
838 | } |
839 | |
840 | if (n->kind() != aten::linear) { |
841 | return false; |
842 | } |
843 | |
844 | auto weight = constant_as<Tensor>(n->namedInput("weight" )).value(); |
845 | return supportedMKLDNNWeight(weight); |
846 | } |
847 | |
848 | bool frozenMkldnnCompatibleConvNode(Node* n) { |
849 | if (nonConstantParameters(n)) { |
850 | return false; |
851 | } |
852 | // mkldnn does not support conv1d |
853 | // _convolution is rewritten before this pass is invoked |
854 | if (n->kind() != aten::conv2d && n->kind() != aten::conv3d) { |
855 | return false; |
856 | } |
857 | |
858 | auto weight = constant_as<Tensor>(n->namedInput("weight" )).value(); |
859 | return supportedMKLDNNWeight(weight); |
860 | } |
861 | |
// [mkldnn perf strategy]
// Certain ops - aten::linear, aten::conv2d, aten::conv3d - provide a huge speed
// up just by converting the constant weights to MKLDNN AOT, and then at runtime
// converting the non-constant input to_mkldnn before the op, and then back to
// its original layout after the op. The speed up holds even if you end up
// converting the input to_mkldnn and the output back to_dense. We start groups
// of ops to compute in MKLDNN only from these ops, which are a strict speedup.
// Then, we expand the groups to include operators which are computable in
// MKLDNN and are roughly perf-equal to eager. We do this in the hope of joining
// multiple fast nodes together, saving to_mkldnn and to_dense conversions.
//
// MKLDNN only supports float32 inputs for aten::linear, aten::conv2d &
// aten::conv3d. We only fuse these nodes if the weights are float32, and then
// we only include operators which we can prove will execute in float32. By
// fusing topologically we can maintain the invariant that all tensor types in
// the graph are floating point. In fusing Conv -> Add -> Relu -> Conv, we start
// with the first Conv, know that its output is float, and can then safely merge
// Add and Relu. If we started with the last Conv, it would be difficult to
// prove in our first pass that the Add's inputs were both float32 without
// first fusing the first Conv.
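//
// For example, a frozen graph fragment such as
//   %a = aten::conv2d(%x, %w, %b, %stride, %padding, %dilation, %groups)
//   %r = aten::relu(%a)
// gets grouped into a single prim::MKLDNNGroup; at runtime %x is converted
// with aten::to_mkldnn before the group and the group's output is converted
// back with aten::to_dense.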
882 | |
883 | class MKLDNNSubgraphSlicer { |
884 | public: |
885 | MKLDNNSubgraphSlicer( |
886 | Block* block, |
887 | std::shared_ptr<Graph> graph, |
888 | AliasDb& aliasDb) |
889 | : block_(block), graph_(std::move(graph)), aliasDb_(aliasDb) {} |
890 | |
891 | void run() { |
// We maintain alias db correctness in-place while building up the MKLDNN
// subgraphs; however, it is difficult to preserve correctness when
// un-inlining those subgraphs. We first recursively construct all
// subgraphs and then unmerge them back into the graph.
896 | buildupSubgraphs(); |
897 | computeSubgraphsInMKLDNN(); |
// Run CSE globally once to eliminate duplicates that may have occurred
899 | // while inlining subgraphs. |
900 | EliminateCommonSubexpression(graph_); |
901 | } |
902 | |
903 | void buildupSubgraphs() { |
// We need to run the slicer multiple times in order to get all merge
// opportunities. This is because moveBeforeTopologicallyValid may reorder
// nodes to be AFTER the current iteration point. In order to properly
// consider those nodes for merging, we need to run the pass until no changes
// have been made.
909 | // |
910 | // Example: |
911 | // c = f(a, b) |
912 | // d = f(c) |
913 | // e = f(d) <- iter is here, moving upward |
914 | // After c.moveBeforeTopologicallyValid(e), we have: |
915 | // c = f(a, b) |
916 | // e = f(d) <- iter still here |
// d = f(c) <- this node was moved to the other side.
918 | |
919 | bool any_changed = true; |
920 | while (any_changed) { |
921 | any_changed = false; |
922 | for (auto it = block_->nodes().begin(); it != block_->nodes().end();) { |
923 | bool changed = false; |
924 | std::tie(it, changed) = scanNode(*it); |
925 | any_changed |= changed; |
926 | } |
927 | } |
928 | |
929 | // Construct Subgraphs Recursively |
930 | for (Node* n : block_->nodes()) { |
931 | for (auto subBlock : n->blocks()) { |
932 | MKLDNNSubgraphSlicer(subBlock, graph_, aliasDb_).buildupSubgraphs(); |
933 | } |
934 | } |
935 | } |
936 | |
937 | static bool MKLDNNGroupStart(Node* node) { |
938 | // if we're already in the process of merging |
939 | if (node->kind() == prim::MKLDNNGroup) { |
940 | return true; |
941 | } |
942 | // see [mkldnn perf strategy] |
943 | return frozenMkldnnCompatibleConvNode(node); |
944 | } |
945 | |
946 | private: |
// MKLDNN only supports float tensors of dimension > 0, so we only support
// Tensors that have a known type or were previously verified
// to be usable in an MKLDNN Group
950 | bool tensorInputIsMKLDNNSupported(Value* v, Node* v_use) { |
951 | auto const_tensor = constant_as<Tensor>(v); |
952 | if (const_tensor) { |
953 | return supportedMKLDNNWeight(*const_tensor); |
954 | } |
955 | auto k = v->node()->kind(); |
956 | if (k == prim::MKLDNNGroup || k == prim::ConstantMKLDNNTensor || |
957 | k == aten::to_mkldnn) { |
958 | return true; |
959 | } |
960 | for (const auto& use : v->uses()) { |
961 | if (use.user->kind() == aten::to_mkldnn && |
962 | v_use->owningBlock() == use.user->owningBlock()) { |
963 | return true; |
964 | } |
965 | } |
966 | return false; |
967 | } |
968 | |
// We include ops here which are roughly perf-equivalent between mkldnn and
// aten (single & multithreaded) and whose inputs & outputs are float32.
971 | bool computableInMKLDNN(Node* n) { |
972 | for (Value* v : n->inputs()) { |
973 | if (v->type()->cast<TensorType>() && |
974 | !(tensorInputIsMKLDNNSupported(v, n))) { |
975 | return false; |
976 | } |
977 | } |
978 | |
979 | if (n->matches( |
980 | "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor" ) && |
981 | n->namedInput("weight" )->type() != NoneType::get() && |
982 | n->namedInput("bias" )->type() != NoneType::get()) { |
983 | auto norm_shape = |
984 | constant_as<std::vector<int64_t>>(n->namedInput("normalized_shape" )); |
985 | return norm_shape.has_value() && norm_shape->size() == 1; |
986 | } |
987 | |
// for unary ops we don't need to prove anything other than
// that the input is mkldnn supported
990 | switch (n->kind()) { |
991 | case aten::relu: |
992 | case aten::relu6: |
993 | case aten::gelu: |
994 | case aten::prelu: |
995 | case aten::sigmoid: |
996 | case aten::hardsigmoid: |
997 | case aten::hardswish: |
998 | case aten::tanh: |
999 | case aten::batch_norm: |
1000 | case aten::max_pool2d: |
1001 | case aten::max_pool3d: |
1002 | case aten::avg_pool2d: |
1003 | case aten::adaptive_avg_pool2d: |
1004 | case aten::avg_pool3d: |
1005 | // case aten::adaptive_max_pool2d: // return tuples which break fusion |
1006 | // case aten::adaptive_max_pool3d: // return tuples which break fusion |
1007 | // case aten::adaptive_avg_pool3d: // no ideep binding |
1008 | return true; |
1009 | } |
1010 | |
1011 | if ((n->kind() == aten::hardtanh || n->kind() == aten::clamp) && |
1012 | !nonConstantParameters(n)) { |
1013 | const size_t MIN_INDEX = 1, MAX_INDEX = 2; |
1014 | auto min_val = constant_as<double>(n->input(MIN_INDEX)).value(); |
1015 | auto max_val = constant_as<double>(n->input(MAX_INDEX)).value(); |
1016 | // we need to maintain the following invariant `pointwise_func(0) == 0`, |
1017 | // see `createUnaryOp` |
1018 | if (min_val <= 0. && max_val >= 0.) { |
1019 | return true; |
1020 | } |
1021 | } |
1022 | |
1023 | if (n->kind() == aten::add) { |
1024 | // mkldnn doesn't currently support Tensor-Scalar add |
1025 | for (const auto i : c10::irange(2)) { |
1026 | if (!n->inputs().at(i)->type()->cast<TensorType>()) { |
1027 | return false; |
1028 | } |
1029 | } |
1030 | return true; |
1031 | } |
1032 | if (n->kind() == aten::mul) { |
1033 | return n->input(0)->type()->cast<TensorType>() && |
1034 | (n->input(1)->type()->cast<TensorType>() || |
1035 | n->input(1)->type()->isSubtypeOf(*NumberType::get())); |
1036 | } |
1037 | |
1038 | if (n->kind() == aten::dropout) { |
1039 | auto train = constant_as<bool>(n->namedInput("train" )).value(); |
return !train;
1041 | } |
1042 | return false; |
1043 | } |
1044 | |
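// For each prim::MKLDNNGroup found (recursively), rewrite its subgraph to run
// in MKLDNN, apply the inplacing pass, and then unmerge the subgraph back
// into the surrounding graph.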
1045 | void computeSubgraphsInMKLDNN() { |
1046 | auto curNode = *block_->nodes().begin(); |
1047 | while (curNode != *block_->nodes().end()) { |
1048 | auto nextNode = curNode->next(); |
1049 | if (curNode->kind() == prim::MKLDNNGroup) { |
1050 | ComputeSubgraphInMKLDNN(curNode); |
1051 | InplaceMKLDNNSubgraph(SubgraphUtils::getSubgraph(curNode)); |
1052 | SubgraphUtils::unmergeSubgraph(curNode); |
1053 | } |
1054 | curNode = nextNode; |
1055 | } |
1056 | for (Node* n : block_->nodes()) { |
1057 | for (Block* b : n->blocks()) { |
1058 | MKLDNNSubgraphSlicer(b, graph_, aliasDb_).computeSubgraphsInMKLDNN(); |
1059 | } |
1060 | } |
1061 | } |
1062 | |
1063 | bool shouldConsiderForMerge(Node* node) { |
1064 | // if we're already in the process of merging |
1065 | if (node->kind() == prim::MKLDNNGroup) { |
1066 | return true; |
1067 | } |
1068 | return frozenMkldnnCompatibleLinearNode(node) || |
1069 | frozenMkldnnCompatibleConvNode(node) || computableInMKLDNN(node); |
1070 | } |
1071 | |
1072 | std::pair<graph_node_list::iterator, bool> scanNode(Node* producer) { |
1073 | if (MKLDNNGroupStart(producer)) { |
1074 | if (producer->kind() != prim::MKLDNNGroup) { |
1075 | producer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing( |
1076 | producer, prim::MKLDNNGroup, aliasDb_); |
1077 | } |
1078 | std::vector<Node*> output_nodes; |
1079 | for (Value* v : producer->outputs()) { |
1080 | for (const auto& use : v->uses()) { |
1081 | output_nodes.push_back(use.user); |
1082 | } |
1083 | } |
1084 | std::sort( |
1085 | output_nodes.begin(), output_nodes.end(), [&](Node* a, Node* b) { |
1086 | return a->isBefore(b); |
1087 | }); |
1088 | for (auto output_node : output_nodes) { |
1089 | if (auto group = tryMerge(producer, output_node)) { |
1090 | // we successfully merged, so the new group's `outputs` may have |
1091 | // changed. So rescan the new group for more merging opportunities. |
1092 | return std::make_pair(group.value()->iterator()++, true); |
1093 | } |
1094 | } |
1095 | } |
1096 | |
1097 | return std::make_pair(++producer->iterator(), false); |
1098 | } |
1099 | |
1100 | // Try to merge `consumer` into `producer`. If successful, this destroys |
1101 | // `consumer` and returns the `producer` group. |
1102 | c10::optional<Node*> tryMerge(Node* producer, Node* consumer) { |
1103 | AT_ASSERT(producer->kind() == prim::MKLDNNGroup); |
1104 | bool canMerge = shouldConsiderForMerge(consumer) && |
1105 | aliasDb_.moveAfterTopologicallyValid(consumer, producer); |
1106 | |
1107 | if (!canMerge) { |
1108 | return c10::nullopt; |
1109 | } |
1110 | |
1111 | SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( |
1112 | consumer, producer, aliasDb_); |
1113 | |
1114 | return producer; |
1115 | } |
1116 | |
1117 | Block* block_; |
1118 | std::shared_ptr<Graph> graph_; |
1119 | AliasDb& aliasDb_; |
1120 | }; |
1121 | |
1122 | bool containsMKLDNNGroup(Block* b) { |
1123 | for (Node* n : b->nodes()) { |
1124 | for (Block* block : n->blocks()) { |
1125 | if (containsMKLDNNGroup(block)) { |
1126 | return true; |
1127 | } |
1128 | } |
1129 | if (MKLDNNSubgraphSlicer::MKLDNNGroupStart(n)) { |
1130 | return true; |
1131 | } |
1132 | } |
1133 | return false; |
1134 | } |
1135 | |
1136 | } // namespace |
1137 | |
1138 | void ConvertFrozenOpsToMKLDNN(std::shared_ptr<Graph>& graph) { |
1139 | GRAPH_DUMP("Before convert frozen ops to mkldnn" , graph); |
1140 | // TODO: replace conv1d with conv2d ? |
1141 | graph_rewrite_helper::replaceConvolutionWithAtenConv(graph); |
1142 | if (containsMKLDNNGroup(graph->block())) { |
1143 | // Only remove tensor mutation if we know we're going to create speedups |
// with mkldnn. Only supporting functional ops simplifies this pass because
1145 | // running an op in mkldnn removes the aliasing relationships that |
1146 | // previously existed between input and output. |
1147 | RemoveTensorMutation(graph, [](Node* node_to_functionalize) { |
1148 | static std::unordered_set<Symbol> mkldnn_ops = { |
1149 | aten::add_, |
1150 | aten::mul_, |
1151 | aten::relu_, |
1152 | aten::relu6_, |
1153 | aten::gelu_, |
1154 | aten::hardswish_, |
1155 | aten::dropout_, |
1156 | aten::sigmoid_, |
1157 | aten::hardsigmoid_, |
1158 | aten::hardtanh_, |
1159 | aten::tanh_, |
1160 | aten::clamp_, |
1161 | }; |
1162 | return mkldnn_ops.count(node_to_functionalize->kind()) != 0; |
1163 | }); |
1164 | |
1165 | AliasDb db(graph); |
1166 | MKLDNNSubgraphSlicer(graph->block(), graph, db).run(); |
1167 | EliminateDeadCode(graph); |
1168 | GRAPH_DUMP("After convert frozen ops to mkldnn" , graph); |
1169 | } else { |
1170 | GRAPH_DUMP("No mkldnn compatible frozen nodes" , graph); |
1171 | } |
1172 | } |
1173 | |
1174 | #else |
1175 | |
1176 | void ConvertFrozenOpsToMKLDNN(std::shared_ptr<Graph>& graph) { |
1177 | GRAPH_DUMP("MKLDNN Not enabled" , graph); |
1178 | } |
1179 | |
1180 | #endif |
1181 | |
1182 | } // namespace jit |
1183 | } // namespace torch |
1184 | |