1#include <ATen/ATen.h>
2#include <ATen/Config.h>
3#include <ATen/Utils.h>
4#include <ATen/core/symbol.h>
5#include <ATen/native/layer_norm.h>
6#include <c10/core/ScalarType.h>
7#include <c10/util/Exception.h>
8#include <c10/util/irange.h>
9
10#include <torch/csrc/jit/ir/alias_analysis.h>
11#include <torch/csrc/jit/ir/constants.h>
12#include <torch/csrc/jit/ir/ir.h>
13#include <torch/csrc/jit/jit_log.h>
14#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
15#include <torch/csrc/jit/passes/constant_propagation.h>
16#include <torch/csrc/jit/passes/dead_code_elimination.h>
17#include <torch/csrc/jit/passes/fold_conv_bn.h>
18#include <torch/csrc/jit/passes/frozen_conv_folding.h>
19#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
20#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
21#include <torch/csrc/jit/passes/peephole.h>
22#include <torch/csrc/jit/passes/remove_mutation.h>
23#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
24#include <torch/csrc/jit/runtime/custom_operator.h>
25#include <torch/csrc/jit/runtime/operator_options.h>
26#include <torch/csrc/jit/tensorexpr/types.h>
27// clang-format off
28// moving ConvUtils include induces import cycle
29#include <ATen/native/ConvUtils.h>
30#include <algorithm>
31#include <memory>
32#include <ATen/core/stack.h>
33#include <c10/core/Layout.h>
34#include <c10/util/StringUtil.h>
35
36#if AT_MKLDNN_ENABLED()
37#include <ATen/CPUFunctions.h>
38#include <dnnl_types.h>
39#include <ATen/native/mkldnn/Utils.h>
40#include <ATen/native/mkldnn/MKLDNNCommon.h>
41#include <ideep.hpp>
42#endif
43
44// clang-format on
45
46namespace torch {
47namespace jit {
48
49#if AT_MKLDNN_ENABLED()
50
51using Tensor = at::Tensor;
52
53namespace {
54
55c10::AliasAnalysisKind aliasAnalysisFromSchema() {
56 return AliasAnalysisKind::FROM_SCHEMA;
57}
58
59using ValueSet = std::unordered_set<Value*>;
60using ValueSetPtr = std::shared_ptr<std::unordered_set<Value*>>;
61
62Node* getLastUse(Value* v) {
63 auto last_use_node = v->node();
64 for (const auto& use : v->uses()) {
65 if (use.user->isAfter(last_use_node)) {
66 last_use_node = use.user;
67 }
68 }
69 return last_use_node;
70}
71
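// Merge the aliasing set of `new_v` into the aliasing set of `existing`,
// repointing every member of the absorbed set at the merged set.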
72void merge_sets(
73 std::unordered_map<Value*, ValueSetPtr>& alias_mapping,
74 Value* existing,
75 Value* new_v) {
76 if (alias_mapping[existing] == alias_mapping[new_v]) {
77 return;
78 }
79 auto existing_set = alias_mapping[existing];
80 auto set_to_remove = alias_mapping[new_v];
81 for (auto it = set_to_remove->begin(); it != set_to_remove->end(); it++) {
82 existing_set->insert(*it);
83 alias_mapping[*it] = existing_set;
84 }
85}
86
87// no uses of tensors in container types
88void assertNonTensorTypeDoesNotContainTensors(TypePtr type) {
89 if (type->cast<TensorType>()) {
90 return;
91 }
92 for (const auto& t : type->containedTypes()) {
93 TORCH_INTERNAL_ASSERT(!t->cast<TensorType>());
94 }
95}
96
97void InplaceMKLDNNSubgraph(std::shared_ptr<Graph> graph) {
  // This function first calculates aliasing sets, then calculates the last
  // node for which each aliasing set is alive. Then we go through each node:
  // if the node has an equivalent in-place variant and the aliasing set of its
  // input is dead after this node, we in-place it. We then merge the aliasing
  // sets of the node's input and output and extend the liveness of the merged
  // set. To in-place a node we need to prove that the device and dtype of the
  // input and output are the same, which we've already done, and that the
  // output size is the same as the input size, which is guaranteed by the
  // explicit Broadcast nodes (which we inserted for other reasons).
  // The graphs here are simple subgraphs without uses of Tensors in
  // containers (Lists, GetAttrs, etc.)
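  //
  // As an illustrative sketch (the IR below is hypothetical), given
  //   %y = aten::relu(%x)
  // where %x's aliasing set has no uses after this node, the node is replaced
  // with its in-place variant
  //   %y = aten::relu_(%x)
  // and the aliasing sets of %x and %y are merged, with the merged set's
  // liveness extended to the liveness end of %y's original set.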
110
111 // CALCULATE ALIASING SETS
112
113 auto aliasDb = torch::make_unique<AliasDb>(graph);
114
115 // map from Value to its Aliasing Set
116 std::unordered_map<Value*, ValueSetPtr> alias_mapping;
117 ValueSet set;
118 ValueSetPtr input_set = std::make_shared<ValueSet>(set);
119 for (Value* v : graph->inputs()) {
120 if (v->type()->cast<TensorType>()) {
121 input_set->insert(v);
122 alias_mapping[v] = input_set;
123 } else {
124 assertNonTensorTypeDoesNotContainTensors(v->type());
125 }
126 }
127
128 for (Node* n : graph->nodes()) {
129 for (Value* output : n->outputs()) {
130 if (!output->type()->cast<TensorType>()) {
131 assertNonTensorTypeDoesNotContainTensors(output->type());
132 continue;
133 }
134
135 std::unordered_set<Value*> new_set = {output};
136 alias_mapping[output] = std::make_shared<ValueSet>(new_set);
137 for (Value* input : n->inputs()) {
138 if (aliasDb->mayAlias(input, output)) {
139 merge_sets(alias_mapping, input, output);
140 }
141 }
142 }
143 }
144
145 // CALCULATE ALIASING SET LIVENESS
146
147 // map from aliased set -> last use of set
148 std::unordered_map<ValueSetPtr, Node*> set_liveness;
149 for (auto& set : alias_mapping) {
150 if (set_liveness.count(set.second)) {
151 continue;
152 }
153 Node* last = nullptr;
154 for (const auto& v : *set.second) {
155 auto k = v->node()->kind();
156 if (k == prim::Constant || k == prim::ConstantMKLDNNTensor ||
157 k == prim::Param) {
158 last = graph->return_node();
159 continue;
160 }
161
162 auto last_use = getLastUse(v);
163 if (!last || last_use->isAfter(last)) {
164 last = last_use;
165 }
166 }
167 set_liveness[set.second] = last;
168 }
169
170 // REUSING MEMORY BY REINPLACING NODES
171 std::vector<Node*> nodes_to_inplace;
172
173 auto add_to_inplace_set = [&](Node* node) {
174 // defer making the inplacing change because that would invalidate the old
175 // Node output Value*
176 nodes_to_inplace.push_back(node);
177 TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
178 auto output_liveness_end =
179 set_liveness[alias_mapping[node->outputs().at(0)]];
180 merge_sets(alias_mapping, node->inputs().at(0), node->output());
181 set_liveness[alias_mapping[node->output()]] = output_liveness_end;
182 };
183
184 for (Node* node : graph->nodes()) {
185 auto k = node->kind();
186 if (k == aten::relu || k == aten::sigmoid || k == aten::dropout ||
187 k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid ||
188 k == prim::MKLDNNHardTanh || k == aten::tanh ||
189 k == prim::MKLDNNClamp || k == Symbol::prim("MKLDNNScalarMul") ||
190 k == Symbol::prim("MKLDNNLayerNorm")) {
191 if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
192 continue;
193 }
194 add_to_inplace_set(node);
195 } else if (k == aten::mul || k == aten::add) {
196 // the binary operators (add/mul) are commutative and only take tensor
197 // inputs, so we can inplace either the first or second input
198 int64_t reusable_value_index = -1;
199 for (const auto i : c10::irange(2)) {
200 TORCH_INTERNAL_ASSERT(node->inputs().at(i)->type()->cast<TensorType>());
201 if (!set_liveness[alias_mapping[node->inputs().at(i)]]->isAfter(node)) {
202 reusable_value_index = i;
203 break;
204 }
205 }
206
207 if (reusable_value_index == -1) {
208 continue;
209 }
210
211 if (reusable_value_index == 1) {
212 node->insertInput(0, node->inputs().at(1));
213 node->removeInput(2);
214 }
215 add_to_inplace_set(node);
216 }
217 }
218
219 for (Node* node : nodes_to_inplace) {
220 node->replaceWithNewSymbol(
221 Symbol::fromQualString(node->schema().name() + "_"));
222 node->destroy();
223 }
224}
225
// This is a factory function that creates an Operation that takes MKLDNN
// tensors and unpacks them into 1D contiguous tensors that we can run aten
// operations on. The precondition for using this function is that the aten
// operation in `aten_op` must map zero to zero, i.e. `aten_op(0) == 0`. The
// reason for this precondition has to do with the blocked formats MKLDNN uses
// to lay out tensor elements (nChw8c, nChw16c, etc.). These split the channel
// dimension into chunks of 8/16 and make it the innermost dimension. Whenever
// the channel dim isn't divisible by 8/16, the innermost dimension is padded
// with 0s. The precondition `aten_op(0) == 0` lets us avoid any special-casing
// of the padded elements.
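// As a concrete illustration (sizes are made up): a tensor with 3 channels
// stored as nChw8c keeps an innermost channel block of 8, so channels 3..7 of
// each block are zero padding. Because the op below runs over the raw
// physical buffer, `aten_op(0) == 0` guarantees that padding stays zero.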
236Operation createUnaryOp(
237 std::function<void(at::Tensor output, at::Tensor input)> aten_op,
238 bool inplace = false) {
239 return [aten_op, inplace](Stack& stack) {
240 auto a = pop(stack).toTensor();
241 c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
    // We cast `a` to an `ideep::tensor` so we can get at its descriptor,
    // which we then use to set up the `out` tensor with the same properties
    // as `a`.
244 auto a_it = at::native::itensor_from_mkldnn(a);
245 auto mkldnn_raw_data = a_it.get_data_handle();
246 auto a_options_with_strided = a.options().layout(c10::kStrided);
247
    // The tensor's physical size can be bigger than its logical size.
    // `a_it.get_desc().get_size()` returns the real physical size (in bytes),
    // which we use to compute `nelem` for the aten ops.
251 auto nelem = static_cast<int64_t>(
252 a_it.get_desc().get_size() / elementSize(a.scalar_type()));
253 // we also wrap `a` storage into an aten tensor
254 auto in_aten =
255 at::from_blob(mkldnn_raw_data, {nelem}, a_options_with_strided);
256
257 auto out_raw_data = mkldnn_raw_data;
258 auto out = a;
259 if (!inplace) {
260 // `a_it.get_desc()` will allocate a tensor
261 // of the right physical size.
262 auto it_empty = ideep::tensor(a_it.get_desc());
263 TORCH_INTERNAL_ASSERT(it_empty.get_desc() == a_it.get_desc());
264 out = at::native::new_with_itensor_mkldnn(
265 std::move(it_empty),
266 optTypeMetaToScalarType(a.options().dtype_opt()),
267 a.options().device_opt());
268
269 out_raw_data = at::native::itensor_from_mkldnn(out).get_data_handle();
270 }
271
272 TORCH_INTERNAL_ASSERT(
273 a_it.get_desc().get_size() % elementSize(a.scalar_type()) == 0);
274
275 auto out_aten = at::from_blob(
276 out_raw_data, {static_cast<int64_t>(nelem)}, a_options_with_strided);
277 aten_op(out_aten, in_aten);
278 push(stack, out);
279 };
280}
281
282void MKLDNNLayerNormOp(Stack& stack, bool inplace) {
283 c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
284
285 // enable_cudnn not used
286 pop(stack);
287 auto eps = pop(stack).toDouble();
288
289 Tensor bias{};
290 Tensor weight{};
291 auto bias_ival = pop(stack);
292 TORCH_INTERNAL_ASSERT(bias_ival.isTensor());
293 bias = bias_ival.toTensor();
294
295 auto weight_ival = pop(stack);
296 TORCH_INTERNAL_ASSERT(weight_ival.isTensor());
297 weight = weight_ival.toTensor();
298
299 auto shape = pop(stack).toDimVector();
300 auto input = pop(stack).toTensor();
301
302 at::Tensor dst, mean, rstd;
303 std::tie(dst, mean, rstd) =
304 at::native::mkldnn_layer_norm_last_index_weight_bias_f32(
305 input, shape, weight, bias, eps, inplace);
306 push(stack, dst);
}
308
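// Broadcasts the two MKLDNN tensor inputs of a binary op to a common size.
// Pops both tensors and pushes back versions (reshaped, or expanded through a
// dense round-trip) whose sizes match at::infer_size of the inputs.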
309Operation BroadOp(const Node* node) {
310 return [](Stack& stack) {
311 auto b = pop(stack).toTensor();
312 auto a = pop(stack).toTensor();
313 auto b_size = b.sizes();
314 auto a_size = a.sizes();
315 if (a_size.equals(b_size)) {
      // TODO: follow up with MKLDNN on the best way
      // to handle perf-incompatible formats
318 push(stack, a, b);
319 return;
320 } else {
321 auto out_size = at::infer_size(a_size, b_size);
322 int64_t out_numel = out_size[0];
323 for (size_t i = 1, end = out_size.size(); i < end; ++i) {
324 out_numel = out_numel * out_size[i];
325 }
326
327 auto exp_a = a;
328 auto exp_b = b;
329 int stacked = 0;
330 // mkldnn tensors only support reshape, not expand or view operators
331 if (a_size.equals(out_size)) {
332 push(stack, a);
333 ++stacked;
334 } else if (out_numel == a.numel()) {
335 exp_a = a.reshape(out_size);
336 } else {
        // TODO: consider initializing to a blocked layout
        // directly if needed
339 exp_a = a.to_dense().expand(out_size).to_mkldnn();
340 }
341
342 if (b_size.equals(out_size)) {
343 push(stack, b);
344 ++stacked;
345 } else if (out_numel == b.numel()) {
346 exp_b = b.reshape(out_size);
347 } else {
348 exp_b = b.to_dense().expand(out_size).to_mkldnn();
349 }
350
351 if (stacked < 2) {
352 if (stacked == 1) {
353 pop(stack);
354 }
355 // If one of the inputs was expanded and converted to nchw/nhwc
356 // we might end up in a very bad spot if the second argument
357 // is in a blocked format. In this case, MKLDNN uses its
358 // reference implementation for a binary operation that follows
359 // these broadcasts and it could be up to ~100x slower.
360 // We use a very simple heuristic to convert an arg in nchw
361 // to the blocked format of the other argument.
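        // For example (a hypothetical case): if `exp_a` is in nChw16c and
        // `exp_b` was just expanded into plain nchw, we reorder `exp_b` into
        // `exp_a`'s blocked descriptor before the binary op runs.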
362 c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
363 auto a_it = at::native::itensor_from_mkldnn(exp_a);
364 auto b_it = at::native::itensor_from_mkldnn(exp_b);
365
        // `is_public_format` means the tensor's physical layout is a plain
        // (non-blocked) layout, e.g. nchw or nhwc but not nChw8c.
368 if (!a_it.is_public_format()) {
369 if (b_it.is_public_format()) {
370 b_it = b_it.reorder_if_differ_in(a_it.get_desc());
371 }
372 } else if (!b_it.is_public_format()) {
373 if (a_it.is_public_format()) {
374 a_it = a_it.reorder_if_differ_in(b_it.get_desc());
375 }
376 }
377
378 auto a_options = exp_a.options();
379 auto a_out = at::native::new_with_itensor_mkldnn(
380 std::move(a_it),
381 optTypeMetaToScalarType(a_options.dtype_opt()),
382 a_options.device_opt());
383 push(stack, a_out);
384 auto b_options = exp_b.options();
385 auto b_out = at::native::new_with_itensor_mkldnn(
386 std::move(b_it),
387 optTypeMetaToScalarType(b_options.dtype_opt()),
388 b_options.device_opt());
389 push(stack, b_out);
      }
391 }
392 };
393}
394
395static std::function<void(at::Tensor output, at::Tensor input)> hardtanh_helper(
396 const Node* n) {
397 auto min_val = n->f(attr::min_val);
398 auto max_val = n->f(attr::max_val);
399 return [min_val, max_val](at::Tensor output, at::Tensor input) {
400 at::cpu::hardtanh_out(output, input, min_val, max_val);
401 };
402}
403
404static std::function<void(at::Tensor output, at::Tensor input)> clamp_helper(
405 const Node* n) {
406 auto min_val = n->f(attr::min_val);
407 auto max_val = n->f(attr::max_val);
408 return [min_val, max_val](at::Tensor output, at::Tensor input) {
409 at::cpu::clamp_out(output, input, min_val, max_val);
410 };
411}
412
413// any op added to this registry needs to meet
414// the precondition: `aten_op(0) == 0`
415const RegisterOperators MKLDNNHardSwishOpReg({
416 torch::jit::Operator(
417 "prim::MKLDNNHardSwish_(Tensor(a!) self) -> Tensor(a!)",
418 createUnaryOp(
419 [](at::Tensor output, at::Tensor input) {
420 at::cpu::hardswish_out(output, input);
421 },
422 true),
423 AliasAnalysisKind::FROM_SCHEMA),
424 torch::jit::Operator(
425 "prim::MKLDNNHardSigmoid_(Tensor(a!) self) -> Tensor(a!)",
426 createUnaryOp(
427 [](at::Tensor output, at::Tensor input) {
428 at::cpu::hardsigmoid_out(output, input);
429 },
430 true),
431 AliasAnalysisKind::FROM_SCHEMA),
432 torch::jit::Operator(
433 "prim::MKLDNNHardTanh_(Tensor(a!) self) -> Tensor(a!)",
434 [](const Node* n) -> Operation {
435 return createUnaryOp(hardtanh_helper(n), true);
436 },
437 AliasAnalysisKind::FROM_SCHEMA),
438 torch::jit::Operator(
439 "prim::MKLDNNClamp_(Tensor(a!) self) -> Tensor(a!)",
440 [](const Node* n) -> Operation {
441 return createUnaryOp(clamp_helper(n), true);
442 },
443 AliasAnalysisKind::FROM_SCHEMA),
444 torch::jit::Operator(
445 "prim::MKLDNNHardSwish(Tensor a) -> Tensor",
446 createUnaryOp(
447 [](at::Tensor output, at::Tensor input) {
448 at::cpu::hardswish_out(output, input);
449 },
450 false),
451 AliasAnalysisKind::FROM_SCHEMA),
452 torch::jit::Operator(
453 "prim::MKLDNNHardSigmoid(Tensor a) -> Tensor",
454 createUnaryOp(
455 [](at::Tensor output, at::Tensor input) {
456 at::cpu::hardsigmoid_out(output, input);
457 },
458 false),
459 AliasAnalysisKind::FROM_SCHEMA),
460 torch::jit::Operator(
461 "prim::MKLDNNHardTanh(Tensor self) -> Tensor",
462 [](const Node* n) -> Operation {
463 return createUnaryOp(hardtanh_helper(n), false);
464 },
465 AliasAnalysisKind::FROM_SCHEMA),
466 torch::jit::Operator(
467 "prim::MKLDNNClamp(Tensor self) -> Tensor",
468 [](const Node* n) -> Operation {
469 return createUnaryOp(clamp_helper(n), false);
470 },
471 AliasAnalysisKind::FROM_SCHEMA),
472});
473
474const RegisterOperators BroadOpReg({
475 torch::jit::Operator(
476 prim::BroadcastMKLDNNTensors,
477 BroadOp,
478 AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
479});
480
481const RegisterOperators MKLDNNLayerNormOpReg({
482 torch::jit::Operator(
483 "prim::MKLDNNLayerNorm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor",
484 [](Stack& stack) { MKLDNNLayerNormOp(stack, false); },
485 AliasAnalysisKind::FROM_SCHEMA),
486 torch::jit::Operator(
487 "prim::MKLDNNLayerNorm_(Tensor(a!) input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor(a!)",
488 [](Stack& stack) { MKLDNNLayerNormOp(stack, true); },
489 AliasAnalysisKind::FROM_SCHEMA),
490});
491
492Operation ConstantMKLDNNTensorOp(const Node* node) {
493 const auto& t = node->t(attr::value);
494 return [t](Stack& stack) {
495 push(stack, t);
496 return 0;
497 };
498}
499
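// Computes `out = scalar * tensor` elementwise by running ideep's
// eltwise_linear primitive with alpha set to the scalar.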
500Tensor mkldnn_tensor_scalar_mul(Tensor& tensor, Tensor& out, float scalar) {
501 ideep::tensor& x = at::native::itensor_from_mkldnn(tensor);
502 ideep::tensor& z = at::native::itensor_from_mkldnn(out);
503 ideep::eltwise_forward::compute(
504 x,
505 z,
506 ideep::algorithm::eltwise_linear,
507 ideep::prop_kind::forward_inference,
508 /*alpha*/ scalar);
509 return out;
510}
511
// aten::convolution does a lot of precomputation and dispatching before
// mkldnn_convolution is called. By registering the op here we can invoke it
// directly and avoid that overhead. Avoiding dispatch overhead for other
// operators (relu, add, etc.) did not benchmark as speeding up models
// noticeably; only the additional overhead of `convolution` warrants the
// custom operator.
517jit::RegisterOperators reg_fut_ops({
518 jit::Operator(
519 // XXX: this follows the schema convention of conv2d/conv3d, not
520 // aten::mkldnn_convolution, which is different for some reason!
521 "prim::mkldnn_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
522 [](jit::Stack& stack) {
523 int64_t groups = pop(stack).toInt();
524 auto dilation = pop(stack).toIntVector();
525 auto padding = pop(stack).toIntVector();
526 auto stride = pop(stack).toIntVector();
527
528 Tensor bias;
529 IValue bias_ival = pop(stack);
530 if (!bias_ival.isNone()) {
531 bias = bias_ival.toTensor();
532 }
533 Tensor weight = pop(stack).toTensor();
534 Tensor input = pop(stack).toTensor();
535
536 at::AutoDispatchBelowAutograd mode;
          // aten::convolution takes care of the size-0 batch case before
          // calling into backends
539 if (input.size(0) == 0) {
540 std::vector<int64_t> o = at::native::conv_output_size(
541 input.sizes(), weight.sizes(), padding, stride, dilation);
542 push(
543 stack,
544 at::native::empty_mkldnn(
545 o,
546 optTypeMetaToScalarType(input.options().dtype_opt()),
547 input.options().layout_opt(),
548 input.options().device_opt(),
549 input.options().pinned_memory_opt()));
550 return;
551 }
552 // aten::convolution also checks dtype mismatches
553 TORCH_CHECK(
554 input.options().type_equal(weight.options()),
555 "Input type (",
556 input.toString(),
557 ") and weight type (",
558 weight.toString(),
559 ") should be the same");
560
561 push(
562 stack,
563 at::native::mkldnn_convolution(
564 input, weight, bias, padding, stride, dilation, groups));
565 },
566 aliasAnalysisFromSchema()),
567 // registering as custom operators avoids Scalar->Tensor->Scalar conversion
568 // in default bindings
569 jit::Operator(
570 "prim::MKLDNNScalarMul(Tensor self, Scalar other) -> Tensor",
571 [](jit::Stack& stack) {
572 c10::impl::ExcludeDispatchKeyGuard edkg(
573 c10::autograd_dispatch_keyset);
574 float other = pop(stack).toScalar().toFloat();
575 Tensor self = pop(stack).toTensor();
576 auto out = at::native::empty_mkldnn(
577 self.sizes(),
578 optTypeMetaToScalarType(self.options().dtype_opt()),
579 self.options().layout_opt(),
580 self.options().device_opt(),
581 self.options().pinned_memory_opt());
582
583 mkldnn_tensor_scalar_mul(self, out, other);
584 push(stack, out);
585 },
586 aliasAnalysisFromSchema()),
587 jit::Operator(
588 "prim::MKLDNNScalarMul_(Tensor(a!) self, Scalar other) -> Tensor(a!)",
589 [](jit::Stack& stack) {
590 c10::impl::ExcludeDispatchKeyGuard edkg(
591 c10::autograd_dispatch_keyset);
592 float other = pop(stack).toScalar().toFloat();
593 Tensor self = pop(stack).toTensor();
594 mkldnn_tensor_scalar_mul(self, self, other);
595 push(stack, self);
596 },
597 aliasAnalysisFromSchema()),
598});
599
// This is registered as its own op instead of as prim::Constant because it
// does not serialize, which is an invariant of prim::Constant.
// TODO: make mkldnn tensors serializable...
603const RegisterOperators MKLDNNConstantOp({
604 torch::jit::Operator(
605 prim::ConstantMKLDNNTensor,
606 ConstantMKLDNNTensorOp,
607 AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
608});
609
610Node* createConstantMKLDNNTensorOp(Graph* g, const Tensor& mkldnn_tensor) {
611 TORCH_INTERNAL_ASSERT(mkldnn_tensor.is_mkldnn());
612 auto op = g->create(prim::ConstantMKLDNNTensor);
613 op->t_(attr::value, mkldnn_tensor);
614 return op;
615}
616
617bool supportedMKLDNNWeight(const Tensor& weight) {
618 return weight.device().is_cpu() && weight.dtype() == c10::ScalarType::Float &&
619 weight.ndimension() != 0;
620}
621
622void replaceInputWithMKLDNNTensor(Node* n, size_t index) {
623 Value* input = n->inputs().at(index);
624 auto mkldnn_tensor = constant_as<Tensor>(input)->to_mkldnn();
625 auto mkldnn_tensor_value =
626 createConstantMKLDNNTensorOp(n->owningGraph(), mkldnn_tensor)
627 ->insertBefore(n)
628 ->output();
629 mkldnn_tensor_value->setDebugName(input->debugName() + "_mkldnn");
630 n->replaceInputWith(input, mkldnn_tensor_value);
631}
632
633void replaceInputWithMKLDNNTensor(
634 Node* n,
635 const std::string& name,
636 const at::Tensor& mkldnn_tensor) {
637 Value* input = n->namedInput(name);
638 auto mkldnn_tensor_value =
639 createConstantMKLDNNTensorOp(n->owningGraph(), mkldnn_tensor)
640 ->insertBefore(n)
641 ->output();
642 mkldnn_tensor_value->setDebugName(input->debugName() + "_mkldnn");
643 n->replaceInputWith(input, mkldnn_tensor_value);
644}
645
646void replaceInputWithMKLDNNTensor(Node* n, const std::string& name) {
647 Value* input = n->namedInput(name);
648 auto mkldnn_tensor = constant_as<Tensor>(input)->to_mkldnn();
649 replaceInputWithMKLDNNTensor(n, name, mkldnn_tensor);
650}
651
652void moveConvWeightsToMKLDNN(Node* conv) {
653 auto conv_w_mkldnn =
654 constant_as<Tensor>(conv->namedInput("weight")).value().to_mkldnn();
655 std::vector<int64_t> padding =
656 toIValue(conv->namedInput("padding"))->toIntVector();
657 std::vector<int64_t> stride =
658 toIValue(conv->namedInput("stride"))->toIntVector();
659 std::vector<int64_t> dilation =
660 toIValue(conv->namedInput("dilation"))->toIntVector();
661 auto groups = constant_as<int64_t>(conv->namedInput("groups")).value();
662
663 if (conv->kind() == aten::conv2d) {
664 conv_w_mkldnn = mkldnn_reorder_conv2d_weight(
665 conv_w_mkldnn, padding, stride, dilation, groups);
666 } else if (conv->kind() == aten::conv3d) {
667 conv_w_mkldnn = mkldnn_reorder_conv3d_weight(
668 conv_w_mkldnn, padding, stride, dilation, groups);
669 } else {
670 TORCH_INTERNAL_ASSERT(false);
671 }
672 replaceInputWithMKLDNNTensor(conv, "weight", conv_w_mkldnn);
673
674 if (conv->namedInput("bias")->type() != NoneType::get()) {
675 replaceInputWithMKLDNNTensor(conv, "bias");
676 }
677}
678
679void moveWeightsToMKLDNN(Node* n) {
  // Conv goes through a special pathway so we can call the mkldnn reorder
  // conv primitive.
682 if (n->kind() == aten::conv2d || n->kind() == aten::conv3d) {
683 moveConvWeightsToMKLDNN(n);
684 } else {
685 for (size_t i = 0; i < n->inputs().size(); ++i) {
686 if (!n->input(i)->type()->cast<TensorType>() ||
687 n->input(i)->node()->kind() != prim::Constant) {
688 continue;
689 }
690 replaceInputWithMKLDNNTensor(n, i);
691 }
692 }
693}
694
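// Replaces `body_node` with a new node of `kind` (prim::MKLDNNHardTanh or
// prim::MKLDNNClamp), stashing the clamp bounds in the min_val / max_val
// attributes that the corresponding registered operator reads back.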
695static void clamp_node_creator(
696 Node* body_node,
697 c10::Symbol kind,
698 double min_val,
699 double max_val) {
700 WithInsertPoint insert_guard{body_node};
701 auto out_node =
702 body_node->owningGraph()->create({kind}, {body_node->input(0)}, 1);
  // N.B. we can't use `insert` as it calls `getOperation` (via
  // `emitBuiltinCall`), which reads the `min_val` and `max_val` attrs that we
  // haven't set yet.
706 body_node->owningGraph()->insertNode(out_node);
707 auto out_val = out_node->output();
708 out_node->f_(attr::min_val, min_val);
709 out_node->f_(attr::max_val, max_val);
710 out_val->copyMetadata(body_node->output());
711 body_node->output()->replaceAllUsesWith(out_val);
712 body_node->destroy();
713}
714
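// Converts a prim::MKLDNNGroup subgraph to compute on MKLDNN tensors: wraps
// the subgraph's tensor inputs in aten::to_mkldnn and its tensor outputs in
// aten::to_dense, moves constant weights into MKLDNN layout, and rewrites
// supported body nodes (add/mul, layer_norm, hardswish, hardsigmoid, relu6,
// hardtanh, clamp, conv2d/conv3d) to their MKLDNN counterparts.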
715void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
716 auto graph = subgraph_node->owningGraph();
717 Value* none_value = nullptr;
718 {
719 WithInsertPoint guard(subgraph_node);
720 none_value = graph->insertConstant(IValue());
721 }
722 for (size_t i = 0; i < subgraph_node->inputs().size(); ++i) {
723 Value* v = subgraph_node->inputs().at(i);
724 if (!v->type()->cast<TensorType>()) {
725 continue;
726 }
727 auto to_mkldnn =
728 graph->create(c10::Symbol::fromQualString("aten::to_mkldnn"), 1)
729 ->insertBefore(subgraph_node);
730 to_mkldnn->addInput(v);
731 to_mkldnn->addInput(none_value);
732 subgraph_node->replaceInput(i, to_mkldnn->output());
733 }
734
735 for (size_t i = 0; i < subgraph_node->outputs().size(); ++i) {
736 Value* v = subgraph_node->outputs().at(i);
737 if (!v->type()->cast<TensorType>()) {
738 continue;
739 }
740 auto from_mkldnn =
741 graph
742 ->create(
743 c10::Symbol::fromQualString("aten::to_dense"), {v, none_value})
744 ->insertAfter(subgraph_node);
745 v->replaceAllUsesAfterNodeWith(from_mkldnn, from_mkldnn->output());
746 }
747
748 auto subgraph = SubgraphUtils::getSubgraph(subgraph_node);
749 for (auto it = subgraph->block()->nodes().begin();
750 it != subgraph->block()->nodes().end();) {
751 Node* body_node = *it;
752 it++;
753
754 moveWeightsToMKLDNN(body_node);
755
756 if (body_node->kind() == aten::add ||
757 (body_node->kind() == aten::mul &&
758 body_node->input(1)->type()->cast<TensorType>())) {
759 auto node = body_node->owningGraph()->create(
760 Symbol::prim("BroadcastMKLDNNTensors"),
761 {body_node->inputs().at(0), body_node->inputs().at(1)},
762 2);
763 node->insertBefore(body_node);
764 body_node->replaceInput(0, node->outputs().at(0));
765 body_node->replaceInput(1, node->outputs().at(1));
766 }
767 if (body_node->kind() == aten::mul &&
768 body_node->input(1)->type()->isSubtypeOf(*NumberType::get())) {
769 body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNScalarMul"));
770 body_node->destroy();
771 continue;
772 }
773
774 if (body_node->matches(
775 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor")) {
776 body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNLayerNorm"));
777 body_node->destroy();
778 continue;
779 }
780
781 if (body_node->kind() == aten::hardswish) {
782 body_node->replaceWithNewSymbol(prim::MKLDNNHardSwish);
783 body_node->destroy();
784 continue;
785 }
786
787 if (body_node->kind() == aten::hardsigmoid) {
788 body_node->replaceWithNewSymbol(prim::MKLDNNHardSigmoid);
789 body_node->destroy();
790 continue;
791 }
792
793 if (body_node->kind() == aten::relu6) {
794 clamp_node_creator(body_node, prim::MKLDNNHardTanh, 0., 6.);
795 continue;
796 }
797
798 if (body_node->kind() == aten::hardtanh) {
799 auto min_val =
800 constant_as<double>(body_node->namedInput("min_val")).value();
801 auto max_val =
802 constant_as<double>(body_node->namedInput("max_val")).value();
803 clamp_node_creator(body_node, prim::MKLDNNHardTanh, min_val, max_val);
804 continue;
805 }
806
807 if (body_node->kind() == aten::clamp) {
808 auto min_val = constant_as<double>(body_node->namedInput("min")).value();
809 auto max_val = constant_as<double>(body_node->namedInput("max")).value();
810 clamp_node_creator(body_node, prim::MKLDNNClamp, min_val, max_val);
811 continue;
812 }
813
814 if (body_node->kind() == aten::conv2d ||
815 body_node->kind() == aten::conv3d) {
      // this node doesn't handle string padding yet...
817 if (!body_node->namedInput("padding")->type()->cast<StringType>()) {
818 body_node->replaceWithNewSymbol(Symbol::prim("mkldnn_convolution"));
819 body_node->destroy();
820 continue;
821 }
822 }
823 }
824}
825
826bool nonConstantParameters(Node* n) {
827 for (size_t i = 1; i < n->inputs().size(); i++) {
828 if (n->inputs().at(i)->node()->kind() != prim::Constant) {
829 return true;
830 }
831 }
832 return false;
833}
834
835bool frozenMkldnnCompatibleLinearNode(Node* n) {
836 if (nonConstantParameters(n)) {
837 return false;
838 }
839
840 if (n->kind() != aten::linear) {
841 return false;
842 }
843
844 auto weight = constant_as<Tensor>(n->namedInput("weight")).value();
845 return supportedMKLDNNWeight(weight);
846}
847
848bool frozenMkldnnCompatibleConvNode(Node* n) {
849 if (nonConstantParameters(n)) {
850 return false;
851 }
852 // mkldnn does not support conv1d
853 // _convolution is rewritten before this pass is invoked
854 if (n->kind() != aten::conv2d && n->kind() != aten::conv3d) {
855 return false;
856 }
857
858 auto weight = constant_as<Tensor>(n->namedInput("weight")).value();
859 return supportedMKLDNNWeight(weight);
860}
861
// [mkldnn perf strategy]
// Certain ops - aten::linear, aten::conv2d, aten::conv3d - provide a huge
// speedup just by converting the constant weights to MKLDNN ahead of time,
// and then at runtime converting the non-constant input to_mkldnn before the
// op and back to its original layout after the op. The speedup holds even if
// you end up converting the input to_mkldnn and the output back to_dense. We
// only start groups of ops to compute in MKLDNN from these ops, which are a
// strict speedup. Then we expand the groups to include operators which are
// computable in MKLDNN and roughly perf-equal to eager. We do this in the
// hope of joining multiple fast nodes together, saving to_mkldnn and to_dense
// conversions.
//
// MKLDNN only supports float32 inputs for aten::linear, aten::conv2d &
// aten::conv3d. We only fuse these nodes if the weights are float32, and then
// we only include operators which we can prove will execute in float32. By
// fusing topologically we can maintain the invariant that all tensor types in
// the graph are floating point. In fusing Conv -> Add -> Relu -> Conv we start
// with the first Conv, know that the output is float, and can then safely
// merge Add and Relu. If we started with the last Conv, it would be difficult
// to prove in our first pass that the Add's inputs were both float32 without
// first fusing the first Conv.
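//
// As a rough sketch of the resulting rewrite (the IR below is illustrative,
// not taken from an actual graph dump):
//   %y = aten::conv2d(%x, %w, %b, ...)
//   %z = aten::relu(%y)
// becomes a prim::MKLDNNGroup whose body runs the conv and relu on MKLDNN
// tensors, with aten::to_mkldnn inserted on %x at the group boundary and
// aten::to_dense on the group's output (see ComputeSubgraphInMKLDNN above).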
882
883class MKLDNNSubgraphSlicer {
884 public:
885 MKLDNNSubgraphSlicer(
886 Block* block,
887 std::shared_ptr<Graph> graph,
888 AliasDb& aliasDb)
889 : block_(block), graph_(std::move(graph)), aliasDb_(aliasDb) {}
890
891 void run() {
    // We maintain alias db correctness in-place while building up the MKLDNN
    // subgraphs; however, it is difficult to preserve correctness when
    // un-inlining them. We first recursively construct all subgraphs and then
    // unmerge them back into the graph.
896 buildupSubgraphs();
897 computeSubgraphsInMKLDNN();
    // Run CSE globally once to eliminate duplicates that may have occurred
    // while inlining subgraphs.
900 EliminateCommonSubexpression(graph_);
901 }
902
903 void buildupSubgraphs() {
    // We need to run the slicer multiple times in order to get all merge
    // opportunities. This is because moveBeforeTopologicallyValid may reorder
    // nodes to be AFTER the current iteration point. In order to properly
    // consider those nodes for merging, we need to run the pass until no
    // changes have been made.
909 //
910 // Example:
911 // c = f(a, b)
912 // d = f(c)
913 // e = f(d) <- iter is here, moving upward
914 // After c.moveBeforeTopologicallyValid(e), we have:
915 // c = f(a, b)
916 // e = f(d) <- iter still here
    //   d = f(c) <- this is the node that was moved to the other side.
918
919 bool any_changed = true;
920 while (any_changed) {
921 any_changed = false;
922 for (auto it = block_->nodes().begin(); it != block_->nodes().end();) {
923 bool changed = false;
924 std::tie(it, changed) = scanNode(*it);
925 any_changed |= changed;
926 }
927 }
928
929 // Construct Subgraphs Recursively
930 for (Node* n : block_->nodes()) {
931 for (auto subBlock : n->blocks()) {
932 MKLDNNSubgraphSlicer(subBlock, graph_, aliasDb_).buildupSubgraphs();
933 }
934 }
935 }
936
937 static bool MKLDNNGroupStart(Node* node) {
938 // if we're already in the process of merging
939 if (node->kind() == prim::MKLDNNGroup) {
940 return true;
941 }
942 // see [mkldnn perf strategy]
943 return frozenMkldnnCompatibleConvNode(node);
944 }
945
946 private:
  // MKLDNN only supports floats of dimension > 0, so we only support
  // Tensors that have a known type or that were previously verified
  // to be usable in an MKLDNN Group.
950 bool tensorInputIsMKLDNNSupported(Value* v, Node* v_use) {
951 auto const_tensor = constant_as<Tensor>(v);
952 if (const_tensor) {
953 return supportedMKLDNNWeight(*const_tensor);
954 }
955 auto k = v->node()->kind();
956 if (k == prim::MKLDNNGroup || k == prim::ConstantMKLDNNTensor ||
957 k == aten::to_mkldnn) {
958 return true;
959 }
960 for (const auto& use : v->uses()) {
961 if (use.user->kind() == aten::to_mkldnn &&
962 v_use->owningBlock() == use.user->owningBlock()) {
963 return true;
964 }
965 }
966 return false;
967 }
968
  // We include ops here which are roughly perf-equivalent between mkldnn and
  // aten (single & multithreaded) and whose inputs & outputs are float32.
971 bool computableInMKLDNN(Node* n) {
972 for (Value* v : n->inputs()) {
973 if (v->type()->cast<TensorType>() &&
974 !(tensorInputIsMKLDNNSupported(v, n))) {
975 return false;
976 }
977 }
978
979 if (n->matches(
980 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor") &&
981 n->namedInput("weight")->type() != NoneType::get() &&
982 n->namedInput("bias")->type() != NoneType::get()) {
983 auto norm_shape =
984 constant_as<std::vector<int64_t>>(n->namedInput("normalized_shape"));
985 return norm_shape.has_value() && norm_shape->size() == 1;
986 }
987
    // For unary ops we don't need to prove anything beyond
    // the input being mkldnn-supported.
990 switch (n->kind()) {
991 case aten::relu:
992 case aten::relu6:
993 case aten::gelu:
994 case aten::prelu:
995 case aten::sigmoid:
996 case aten::hardsigmoid:
997 case aten::hardswish:
998 case aten::tanh:
999 case aten::batch_norm:
1000 case aten::max_pool2d:
1001 case aten::max_pool3d:
1002 case aten::avg_pool2d:
1003 case aten::adaptive_avg_pool2d:
1004 case aten::avg_pool3d:
1005 // case aten::adaptive_max_pool2d: // return tuples which break fusion
1006 // case aten::adaptive_max_pool3d: // return tuples which break fusion
1007 // case aten::adaptive_avg_pool3d: // no ideep binding
1008 return true;
1009 }
1010
1011 if ((n->kind() == aten::hardtanh || n->kind() == aten::clamp) &&
1012 !nonConstantParameters(n)) {
1013 const size_t MIN_INDEX = 1, MAX_INDEX = 2;
1014 auto min_val = constant_as<double>(n->input(MIN_INDEX)).value();
1015 auto max_val = constant_as<double>(n->input(MAX_INDEX)).value();
1016 // we need to maintain the following invariant `pointwise_func(0) == 0`,
1017 // see `createUnaryOp`
1018 if (min_val <= 0. && max_val >= 0.) {
1019 return true;
1020 }
1021 }
1022
1023 if (n->kind() == aten::add) {
1024 // mkldnn doesn't currently support Tensor-Scalar add
1025 for (const auto i : c10::irange(2)) {
1026 if (!n->inputs().at(i)->type()->cast<TensorType>()) {
1027 return false;
1028 }
1029 }
1030 return true;
1031 }
1032 if (n->kind() == aten::mul) {
1033 return n->input(0)->type()->cast<TensorType>() &&
1034 (n->input(1)->type()->cast<TensorType>() ||
1035 n->input(1)->type()->isSubtypeOf(*NumberType::get()));
1036 }
1037
1038 if (n->kind() == aten::dropout) {
1039 auto train = constant_as<bool>(n->namedInput("train")).value();
1040 return train == false;
1041 }
1042 return false;
1043 }
1044
1045 void computeSubgraphsInMKLDNN() {
1046 auto curNode = *block_->nodes().begin();
1047 while (curNode != *block_->nodes().end()) {
1048 auto nextNode = curNode->next();
1049 if (curNode->kind() == prim::MKLDNNGroup) {
1050 ComputeSubgraphInMKLDNN(curNode);
1051 InplaceMKLDNNSubgraph(SubgraphUtils::getSubgraph(curNode));
1052 SubgraphUtils::unmergeSubgraph(curNode);
1053 }
1054 curNode = nextNode;
1055 }
1056 for (Node* n : block_->nodes()) {
1057 for (Block* b : n->blocks()) {
1058 MKLDNNSubgraphSlicer(b, graph_, aliasDb_).computeSubgraphsInMKLDNN();
1059 }
1060 }
1061 }
1062
1063 bool shouldConsiderForMerge(Node* node) {
1064 // if we're already in the process of merging
1065 if (node->kind() == prim::MKLDNNGroup) {
1066 return true;
1067 }
1068 return frozenMkldnnCompatibleLinearNode(node) ||
1069 frozenMkldnnCompatibleConvNode(node) || computableInMKLDNN(node);
1070 }
1071
1072 std::pair<graph_node_list::iterator, bool> scanNode(Node* producer) {
1073 if (MKLDNNGroupStart(producer)) {
1074 if (producer->kind() != prim::MKLDNNGroup) {
1075 producer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
1076 producer, prim::MKLDNNGroup, aliasDb_);
1077 }
1078 std::vector<Node*> output_nodes;
1079 for (Value* v : producer->outputs()) {
1080 for (const auto& use : v->uses()) {
1081 output_nodes.push_back(use.user);
1082 }
1083 }
1084 std::sort(
1085 output_nodes.begin(), output_nodes.end(), [&](Node* a, Node* b) {
1086 return a->isBefore(b);
1087 });
1088 for (auto output_node : output_nodes) {
1089 if (auto group = tryMerge(producer, output_node)) {
        // We successfully merged, so the new group's `outputs` may have
        // changed. Rescan the new group for more merging opportunities.
1092 return std::make_pair(group.value()->iterator()++, true);
1093 }
1094 }
1095 }
1096
1097 return std::make_pair(++producer->iterator(), false);
1098 }
1099
1100 // Try to merge `consumer` into `producer`. If successful, this destroys
1101 // `consumer` and returns the `producer` group.
1102 c10::optional<Node*> tryMerge(Node* producer, Node* consumer) {
1103 AT_ASSERT(producer->kind() == prim::MKLDNNGroup);
1104 bool canMerge = shouldConsiderForMerge(consumer) &&
1105 aliasDb_.moveAfterTopologicallyValid(consumer, producer);
1106
1107 if (!canMerge) {
1108 return c10::nullopt;
1109 }
1110
1111 SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
1112 consumer, producer, aliasDb_);
1113
1114 return producer;
1115 }
1116
1117 Block* block_;
1118 std::shared_ptr<Graph> graph_;
1119 AliasDb& aliasDb_;
1120};
1121
1122bool containsMKLDNNGroup(Block* b) {
1123 for (Node* n : b->nodes()) {
1124 for (Block* block : n->blocks()) {
1125 if (containsMKLDNNGroup(block)) {
1126 return true;
1127 }
1128 }
1129 if (MKLDNNSubgraphSlicer::MKLDNNGroupStart(n)) {
1130 return true;
1131 }
1132 }
1133 return false;
1134}
1135
1136} // namespace
1137
1138void ConvertFrozenOpsToMKLDNN(std::shared_ptr<Graph>& graph) {
1139 GRAPH_DUMP("Before convert frozen ops to mkldnn", graph);
1140 // TODO: replace conv1d with conv2d ?
1141 graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
1142 if (containsMKLDNNGroup(graph->block())) {
    // Only remove tensor mutation if we know we're going to create speedups
    // with mkldnn. Only supporting functional ops simplifies this pass because
    // running an op in mkldnn removes the aliasing relationships that
    // previously existed between input and output.
1147 RemoveTensorMutation(graph, [](Node* node_to_functionalize) {
1148 static std::unordered_set<Symbol> mkldnn_ops = {
1149 aten::add_,
1150 aten::mul_,
1151 aten::relu_,
1152 aten::relu6_,
1153 aten::gelu_,
1154 aten::hardswish_,
1155 aten::dropout_,
1156 aten::sigmoid_,
1157 aten::hardsigmoid_,
1158 aten::hardtanh_,
1159 aten::tanh_,
1160 aten::clamp_,
1161 };
1162 return mkldnn_ops.count(node_to_functionalize->kind()) != 0;
1163 });
1164
1165 AliasDb db(graph);
1166 MKLDNNSubgraphSlicer(graph->block(), graph, db).run();
1167 EliminateDeadCode(graph);
1168 GRAPH_DUMP("After convert frozen ops to mkldnn", graph);
1169 } else {
1170 GRAPH_DUMP("No mkldnn compatible frozen nodes", graph);
1171 }
1172}
1173
1174#else
1175
1176void ConvertFrozenOpsToMKLDNN(std::shared_ptr<Graph>& graph) {
1177 GRAPH_DUMP("MKLDNN Not enabled", graph);
1178}
1179
1180#endif
1181
1182} // namespace jit
1183} // namespace torch
1184