#include <torch/csrc/jit/passes/tensorexpr_fuser.h>

#include <ATen/core/interned_strings.h>
#include <ATen/core/symbol.h>
#include <ATen/record_function.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/jit_opt_limit.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/remove_redundant_profiles.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/operator_options.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/utils/memory.h>

#include <utility>

// NOLINTNEXTLINE
C10_DEFINE_bool(
    torch_jit_disable_cat,
    false,
    "disable aten::cat in TE fusion groups");

C10_DEFINE_bool(
    torch_jit_enable_dynamic_shape_fusion,
    false,
    "enable TE fusion using dynamic shapes");

namespace torch {
namespace jit {

static bool texpr_reductions_enabled = false;

bool isSupportedForBlock(Node* node) {
  switch (node->kind()) {
    case aten::add:
    case aten::mul:
      return true;
    default:
      return false;
  }
}

bool usedOnlyInSize(Value* v) {
  const auto& uses = v->uses();
  return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
    return u.user->matches("aten::size(Tensor self) -> int[]");
  });
}

Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db) {
  AT_ASSERT(!sizes.empty());
  Graph* graph = sizes[0]->owningGraph();
  Node* broadcast_n =
      graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
  broadcast_n->output()->setType(ListType::ofInts());
  db->createValue(broadcast_n->output());
  return broadcast_n->output();
}

namespace tensorexpr {

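// Operators added to this set (e.g. by external/custom-op integrations) are
// treated as fusible by isSupported() below, in addition to the built-in
// elementwise and non-elementwise sets.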
OperatorSet& getCustomOperatorSet() {
  static OperatorSet _g_custom_operator_set{};
  return _g_custom_operator_set;
}

static const OperatorSet& supported_non_eltwise_set() {
  // clang-format off
  static const OperatorSet supported_non_eltwise_set{
      "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
      "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
      "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
      "aten::matmul(Tensor self, Tensor other) -> Tensor",
  };
  // clang-format on
  return supported_non_eltwise_set;
}

bool isSupported(Node* node) {
  // For Block codegen we allow limited ops.
  if (tensorexpr::getTEGenerateBlockCode()) {
    return isSupportedForBlock(node);
  }

  // clang-format off
  static const OperatorSet supported_reduction_set{
      "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
      "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
      "aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor",
      "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
  };
  static const OperatorSet supported_misc_set{
      "aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
      "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)",
  };
  // clang-format on

  if (get_tensorexpr_elementwise_set().contains(node) ||
      node->isMemberOf(supported_non_eltwise_set()) ||
      node->isMemberOf(supported_misc_set) ||
      node->isMemberOf(getCustomOperatorSet()) ||
      (texpr_reductions_enabled && node->isMemberOf(supported_reduction_set))) {
    // We only insert guards on Tensor types, so we rely on the output
    // of a node being uniquely determined by its input types.
    // Bail if any non-Tensor input affects the output type
    // and cannot be reasoned about statically.

    // Value is either an int or a float (can occur from .item())
    for (Value* v : node->inputs()) {
      if (v->type()->cast<NumberType>()) {
        return false;
      }
    }

    // non-const dtype / device
    for (auto arg_name : {"dtype", "device"}) {
      if (auto index = node->schema().argumentIndexWithName(arg_name)) {
        if (!toIValue(node->input(*index))) {
          return false;
        }
      }
    }

    if (FLAGS_torch_jit_disable_cat && node->kind() == aten::cat) {
      return false;
    }

    return true;
  }

  // unschematized ops
  switch (node->kind()) {
    case prim::ConstantChunk:
    case prim::ListConstruct:
    case prim::TensorExprGroup:
      return true;
  }

  return false;
}
} // namespace tensorexpr

static bool texpr_fuser_enabled_ = true;

void setTensorExprFuserEnabled(bool val) {
  texpr_fuser_enabled_ = val;
}

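// The fuser is enabled by default. The PYTORCH_TENSOREXPR environment
// variable, when set, takes precedence over setTensorExprFuserEnabled():
// e.g. PYTORCH_TENSOREXPR=0 turns the fuser off, and any other value turns
// it on.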
bool tensorExprFuserEnabled() {
  static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR");
  if (!enable_c_str) {
    return texpr_fuser_enabled_;
  }
  if (std::string(enable_c_str) == "0") {
    return false;
  }
  return true;
}

bool tensorExprDynamicShapeFusionEnabled() {
  return FLAGS_torch_jit_enable_dynamic_shape_fusion;
}

void setTensorExprDynamicShapeFusionEnabled(bool val) {
  FLAGS_torch_jit_enable_dynamic_shape_fusion = val;
}

bool setTexprReductionsEnabled(bool value) {
  bool old_value = texpr_reductions_enabled;
  texpr_reductions_enabled = value;
  return old_value;
}

bool texprReductionsEnabled() {
  return texpr_reductions_enabled;
}

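// Replaces prim::profile nodes with their input and specializes that input's
// type to the recorded profiled type. For example,
//   %y = prim::profile[profiled_type=Float(2, 3)](%x)
//   %z = aten::relu(%y)
// becomes
//   %z = aten::relu(%x)   (with %x now typed Float(2, 3))
// Profiles on optional-tensor inputs are dropped without specializing.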
void removeProfileNodesAndSpecializeTypes(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
    if (it->kind() == prim::profile) {
      GRAPH_DEBUG("Removing prim::profile: %", it->output()->debugName());
      it->output()->replaceAllUsesWith(it->input());
      auto profiled_type = it->ty(attr::profiled_type)->expect<TensorType>();

      TensorTypePtr input_tensor_type = nullptr;
      bool input_is_optional = false;
      if (it->input()->type()->kind() == c10::TypeKind::TensorType) {
        input_tensor_type = it->input()->type()->expect<TensorType>();
      } else {
        input_tensor_type = it->input()
                                ->type()
                                ->expectRef<OptionalType>()
                                .getElementType()
                                ->expect<TensorType>();
        input_is_optional = true;
      }

      if (input_is_optional) {
        it.destroyCurrent();
        continue;
      }

      // A value can be profiled with differently typed uses.
      // This can occur from:
      // - having a use which is not executed, so the type will be
      //   TensorType::get()
      // - control-flow that depends on tensor type:
      //   if x.size() == 2 op(x) else op(x)
      // - mutation of the value on a field represented in the tensor type
      //   op(x); x.resize_([...]); op(x)

      // The most common case today with num_profiles = 1 is from the first
      // case. Here we can just ignore non-profiled uses, and choose any of the
      // profiled uses. Because we guard all tensor types in the runtime, even
      // if we set a Value to have a profiled type from one use and then execute
      // a use with a different profiled type, we will still be correct.
      // In the future we could consider unifying the types of uses, or adding a
      // type refinement node so uses can have the correct corresponding type.
      if (profiled_type == TensorType::get()) {
        continue;
      }

      // If we encounter non-identical profiled types for the same value, merge
      // them. This situation can happen if, e.g., loop unrolling duplicates
      // profiled types in a loop body in a manner that isn't logically
      // consistent (see TestTEFuser.test_unrolled_cat).
      if (input_tensor_type == TensorType::get()) {
        it->input()->setType(profiled_type);
      } else {
        it->input()->setType(input_tensor_type->merge(*profiled_type));
      }

      it.destroyCurrent();
    } else {
      for (Block* ib : it->blocks()) {
        removeProfileNodesAndSpecializeTypes(ib);
      }
    }
  }
}

void RemoveProfileNodesAndSpecializeTypes(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG("Before removeProfileNodesAndSpecializeTypes:\n", *graph);
  removeProfileNodesAndSpecializeTypes(graph->block());
  GRAPH_DEBUG("After removeProfileNodesAndSpecializeTypes:\n", *graph);
}

bool hasTensorTypeSpecialization(Value* v) {
  if (!v->type()->cast<TensorType>()) {
    return false;
  }
  // Constants & TensorExprGroup always produce a specialized tensor type, and
  // TypeCheck nodes are inserted by this pass and are only used by fusion
  // groups that insert proper guards.
  if (v->node()->kind() == prim::Constant ||
      v->node()->kind() == prim::TypeCheck ||
      v->node()->kind() == prim::TensorExprGroup) {
    return false;
  }
  if (v->type() == TensorType::get()) {
    return false;
  }
  return true;
}

void removeTensorTypeSpecialization(Value* v) {
  if (hasTensorTypeSpecialization(v)) {
    v->setType(TensorType::get());
  }
}

void removeTensorTypeSpecializations(Block* block) {
  for (Value* v : block->inputs()) {
    removeTensorTypeSpecialization(v);
  }
  for (Node* n : block->nodes()) {
    for (Block* b : n->blocks()) {
      removeTensorTypeSpecializations(b);
    }
    for (Value* v : n->outputs()) {
      removeTensorTypeSpecialization(v);
    }
  }
}

void RemoveTensorTypeSpecializations(std::shared_ptr<Graph>& graph) {
  removeTensorTypeSpecializations(graph->block());
}

void insertTypeGuard(
    Node* guarded_node,
    tensor_type_converter_t type_converter,
    Symbol kind) {
  GRAPH_DEBUG("Inserting a typecheck guard for a node", *guarded_node);
  auto subgraph = SubgraphUtils::getSubgraph(guarded_node);

  // Fixup types of the subgraph inputs
  std::vector<Value*> inputs_to_check;
  std::vector<TypePtr> guard_types;
  for (Value* input : guarded_node->inputs()) {
    // We only check the inputs of the guarded node and expect the user to
    // infer intermediate and output shapes.
    if (!input->type()->cast<TensorType>()) {
      continue;
    }

    // fusion outputs are already guarded
    if (input->node()->kind() == prim::Constant ||
        input->node()->kind() == prim::FusionGroup) {
      continue;
    }
    inputs_to_check.push_back(input);
    guard_types.emplace_back(
        type_converter(input->type()->expect<TensorType>()));
  }
  if (inputs_to_check.empty()) {
    return;
  }

  // Add prim::TypeCheck node
  //
  // TypeCheck nodes look like the following:
  //   %out1 : Float(2, 3), %out2 : Int(10, 30), %types_match : bool =
  //       prim::TypeCheck(%inp1 : Tensor, %inp2 : Tensor)
  //
  // They have N inputs whose types we are going to check and N+1 outputs. The
  // first N outputs carry the input values narrowed to the expected types, and
  // the (N+1)-th output holds the result of the check (bool).
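  //
  // After this pass the guarded node ends up wrapped in a versioning if, e.g.:
  //   %t0, %t1, %ok = prim::TypeCheck[types=[...]](%inp0, %inp1)
  //   %outs = prim::If(%ok)
  //     block0():  # types matched: run the fused kernel on the checked values
  //       %fused = prim::TensorExprGroup(%t0, %t1)
  //       -> (%fused)
  //     block1():  # fallback: unoptimized copy of the subgraph
  //       ...
  //       -> (...)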
  Node* typecheck_node =
      guarded_node->owningGraph()
          ->create(kind, inputs_to_check, inputs_to_check.size() + 1)
          ->insertBefore(guarded_node);
  typecheck_node->tys_(attr::types, std::move(guard_types));
  Value* typecheck_result = typecheck_node->output(inputs_to_check.size());

  std::unordered_map<Value*, Value*> typechecked_inputs;
  for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) {
    typechecked_inputs[typecheck_node->input(i)] = typecheck_node->output(i);
  }

  // Fixup types of the typecheck node outputs, which are used by the op in
  // execution
  typecheck_node->output(inputs_to_check.size())->setType(BoolType::get());
  for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) {
    typecheck_node->output(i)->setType(typecheck_node->input(i)->type());
  }

  // Insert if
  auto versioning_if =
      guarded_node->owningGraph()
          ->create(prim::If, {typecheck_result}, guarded_node->outputs().size())
          ->insertAfter(typecheck_node);
  for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) {
    versioning_if->output(idx)->setType(guarded_node->output(idx)->type());
    guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
  }
  auto true_block = versioning_if->addBlock();
  auto false_block = versioning_if->addBlock();

  // Fill in the false block. It should contain the unoptimized
  // copy of the fused subgraph.
  WithInsertPoint guard(false_block->return_node());
  const auto subgraph_outputs = insertGraph(
      *guarded_node->owningGraph(), *subgraph, guarded_node->inputs());
  for (Value* output : subgraph_outputs) {
    false_block->registerOutput(output);
  }

  // types get copied to the fallback graph, so remove specializations before
  // replacing
  removeTensorTypeSpecializations(false_block);
  replaceBlockWithFallbackGraph(false_block, guarded_node->inputs());

  // Fill in the true block. It has all inputs type-checked and its
  // body should be the fusion group node.
  guarded_node->moveBefore(true_block->return_node());
  for (size_t idx = 0; idx < guarded_node->inputs().size(); ++idx) {
    if (typechecked_inputs.count(guarded_node->input(idx))) {
      guarded_node->replaceInput(
          idx, typechecked_inputs.at(guarded_node->input(idx)));
    }
  }
  for (Value* output : guarded_node->outputs()) {
    true_block->registerOutput(output);
  }
}

namespace {
bool has_unsupported_pin_memory(const Node* node) {
  // can't support non-constant pin_memory or pin_memory = True
  if (auto maybe_index = node->schema().argumentIndexWithName("pin_memory")) {
    int index = *maybe_index;
    auto inp = node->input(index);
    if (inp->type() != NoneType::get() &&
        constant_as<bool>(inp).value_or(true)) {
      return true;
    }
  }
  return false;
}
} // namespace

class TensorExprFuser {
 public:
  TensorExprFuser(
      std::shared_ptr<Graph> graph,
      size_t min_group_size,
      bool add_composed_op,
      bool fuse_to_dynamic_shapes)
      : graph_(std::move(graph)),
        min_group_size_(min_group_size),
        add_composed_op_(add_composed_op),
        fuse_to_dynamic_shapes_(fuse_to_dynamic_shapes) {
    parseTENotFuseOption();
  }

  // Builds up expressions that compute shapes of all intermediates (and
  // outputs) of the fusion group, based on the sizes of inputs. You should run
  // DCE to remove those that you end up not using.
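  //
  // For example, for a fusion group that computes %c = aten::add(%a, %b, 1),
  // this emits (after the group, in the outer graph):
  //   %a_size = aten::size(%a)
  //   %b_size = aten::size(%b)
  //   %c_size = prim::BroadcastSizes(%a_size, %b_size)
  // and maps the subgraph value %c to %c_size.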
  std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
    GRAPH_DUMP("buildShapeExpressions for ", fusion_group->g(attr::Subgraph));
    WithInsertPoint insert_guard{fusion_group->next()};
    std::unordered_map<Value*, Value*> shape_of;

    Graph* graph = fusion_group->owningGraph();
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto inputs = fusion_group->inputs();
    auto sinputs = subgraph->inputs();
    AT_ASSERT(inputs.size() == sinputs.size());
    for (const auto i : c10::irange(inputs.size())) {
      if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
        Value* soutput = graph->insert(aten::size, {inputs[i]});
        aliasDb_->createValue(soutput);
        GRAPH_DEBUG(
            "Adding a mapping for %",
            sinputs[i]->debugName(),
            " ",
            getHeader(soutput->node()));
        shape_of[sinputs[i]] = soutput;
      }
    }

    // When we have a guarantee that an output won't be removed, because it's
    // used in expressions that don't involve size checks, we can use its size
    // instead of computing a long chain of broadcasts, starting from the
    // beginning of the kernel.
    auto outputs = fusion_group->outputs();
    auto soutputs = subgraph->outputs();
    AT_ASSERT(outputs.size() == soutputs.size());
    for (const auto i : c10::irange(outputs.size())) {
      if (usedOnlyInSize(outputs[i]))
        continue;
      Value* soutput = graph->insert(aten::size, {outputs[i]});
      aliasDb_->createValue(soutput);
      shape_of[soutputs[i]] = soutput;
    }

    for (Node* n : subgraph->nodes()) {
      auto tensor_inputs = filter(n->inputs(), [](Value* v) {
        return v->type()->isSubtypeOf(*TensorType::get());
      });
      GRAPH_DEBUG("Building sizes for ", getHeader(n));
      bool all_inputs_have_sizes = true;
      auto shapes = fmap(tensor_inputs, [&](Value* v) {
        GRAPH_DEBUG("Getting aten::size for %", v->debugName());
        all_inputs_have_sizes &= shape_of.count(v);
        return shape_of.count(v) != 0 ? shape_of.at(v) : nullptr;
      });
      if (!all_inputs_have_sizes) {
        GRAPH_DEBUG(
            "Not all tensor arguments have sizes available to compute the broadcasted size",
            getHeader(n));
        continue;
      }

      if (n->kind() == prim::ConstantChunk) {
        Node* sizes_node = graph->insertNode(
            graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
        sizes_node->i_(attr::dim, n->i(attr::dim));
        sizes_node->i_(attr::chunks, n->i(attr::chunks));
        for (Value* output : sizes_node->outputs()) {
          aliasDb_->createValue(output);
        }
        Value* regular_size = sizes_node->outputs().at(0);
        Value* last_size = sizes_node->outputs().at(1);
        regular_size->setType(ListType::ofInts());
        last_size->setType(ListType::ofInts());
        auto outputs = n->outputs();
        for (Value* o : outputs.slice(0, outputs.size() - 1)) {
          shape_of.emplace(o, regular_size);
        }
        shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
        continue;
      }

      // We only support shape calculations for elementwise ops, some
      // non-elementwise ops like batch_norm, conv, and matmul, and a few
      // exceptions (e.g. prim::ConstantChunk) handled above.
      if (!(get_tensorexpr_elementwise_set().contains(n)) &&
          !n->isMemberOf(tensorexpr::supported_non_eltwise_set())) {
        continue;
      }

      shape_of.emplace(
          n->output(),
          shapes.size() == 1 ? shapes[0]
                             : broadcastSizes(shapes, aliasDb_.get()));
    }
    return shape_of;
  }

  void removeOutputsUsedOnlyInSize(Node* fusion_group) {
    if (fusion_group->kind() != prim::TensorExprGroup)
      return;
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto shape_of = buildShapeExpressions(fusion_group);
    auto outputs = fusion_group->outputs().vec();
    auto soutputs = subgraph->outputs().vec();
    // XXX: Iterating in this order is not only good for performance reasons!
    // It is also crucial for correctness (i has to reflect the current true
    // index of outputs[i])!
    for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
      auto output = outputs[i];
      auto soutput = soutputs[i];
      if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
        auto uses = output->uses();
        for (Use u : uses) {
          AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
          u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
          u.user->destroy();
        }
        fusion_group->eraseOutput(i);
        subgraph->eraseOutput(i);
      }
    }
  }

  void run() {
    aliasDb_ = torch::make_unique<AliasDb>(graph_);
    RemoveRedundantProfiles(graph_);
    GRAPH_DUMP("After removing redundant profile nodes: ", graph_);
    createFusionGroups(graph_->block());
    GRAPH_DUMP("After creating fusion groups: ", graph_);
    // We maintain alias-db correctness during the initial fusion, but it is
    // difficult to maintain after inlining, so inline small groups only once
    // fusion is done.
    inlineSmallFusionGroups(graph_->block());
    GRAPH_DUMP("After inlining small fusion groups: ", graph_);
    if (fuse_to_dynamic_shapes_) {
      VLOG(1) << "TensorExpr fusion with dynamic shapes is enabled"
              << std::endl;
      generalizeFusionGroups(graph_->block());
      GRAPH_DUMP("After generalizing fusion groups: ", graph_);
    } else {
      prepareFusionGroupAndGuardOutputs(graph_->block());
      GRAPH_DUMP("After guarding fusion groups: ", graph_);
    }
  }

 private:
  Node* getOrCreateTensorExprSubgraph(Node* n) {
    if (n->hasAttribute(attr::Subgraph) && n->kind() == prim::TensorExprGroup) {
      return n;
    }
    GRAPH_UPDATE("Creating a tensorexpr::Group node from: ", *n);
    return SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
        n, prim::TensorExprGroup, *aliasDb_);
  }

  value_list sortReverseTopological(ArrayRef<Value*> inputs, Block* b) {
    value_list result;
    for (auto i : inputs) {
      if (i->node()->owningBlock() == b) {
        result.push_back(i);
      }
    }
    // Sort in reverse topological order
    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
      return a->node()->isAfter(b->node());
    });
    return result;
  }

  // Create a fusion group starting from the node N.
  // We then try to pull inputs into the fusion group and repeat that process
  // until there is nothing we can pull in.
  std::pair<graph_node_list::iterator, bool> createFusionGroup(
      Node* fusion_node) {
    // Allow single-node groups containing conv2d, since we'll only select
    // those in cases where the tensorexpr implementation is faster than the
    // aten implementation.
    if (min_group_size_ == 1 || fusion_node->kind() == aten::conv2d) {
      fusion_node = getOrCreateTensorExprSubgraph(fusion_node);
    }

    GRAPH_DEBUG("Iteratively pull input nodes into the fusion group...\n");
    auto inputs = sortReverseTopological(
        fusion_node->inputs(), fusion_node->owningBlock());
    for (auto input : inputs) {
      debugDumpFusionGroup("Current fusion group: ", fusion_node);
      GRAPH_DEBUG("Trying to merge: ", *input->node());
      if (auto maybe_fusion_group = tryMerge(fusion_node, input->node())) {
        // we successfully merged, so the new group's `inputs` may have
        // changed. So rescan the new group for more merging opportunities.
        return std::make_pair(
            maybe_fusion_group.value()->reverseIterator(), true);
      }
    }

    return std::make_pair(++fusion_node->reverseIterator(), false);
  }

  static void debugDumpFusionGroup(const std::string& msg, Node* n) {
    // NOLINTNEXTLINE(clang-analyzer-core.NonNullParamChecker)
    GRAPH_DEBUG(msg, *n);
    // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
    if (n->kind() == prim::TensorExprGroup) {
      GRAPH_DEBUG(*n->g(attr::Subgraph));
    }
  }

  // Ops that are no-ops in eager shouldn't become outputs of fusion groups
  // because that would degrade perf and change aliasing relationships.
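  // e.g. an aten::to(x, ...) whose output tensor type is identical to its
  // input's type, as checked below.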
  static bool unexecutedEagerOp(Node* n) {
    if (n->kind() != aten::to &&
        n->kind() != aten::_autocast_to_reduced_precision &&
        n->kind() != aten::_autocast_to_full_precision) {
      return false;
    }

    return *n->input(0)->type()->expect<TensorType>() ==
        *n->output()->type()->expect<TensorType>();
  }

  std::pair<graph_node_list::iterator, bool> scanNode(Node* n) {
    GRAPH_DEBUG("Considering node:", *n)

    if (!canHandle(n)) {
      return std::make_pair(++n->reverseIterator(), false);
    }
    // There are some nodes that we can support, but we don't want to start a
    // fusion group from - skip them.
    if (n->kind() == prim::ListConstruct || n->kind() == aten::slice ||
        n->kind() == aten::unsqueeze || n->kind() == prim::ConstantChunk ||
        n->kind() == prim::Constant || unexecutedEagerOp(n)) {
      return std::make_pair(++n->reverseIterator(), false);
    }
    return createFusionGroup(n);
  }

  // Merge fusible nodes into subgraphs in prim::TensorExprGroup nodes.
  void createFusionGroups(Block* block) {
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        bool changed;
        std::tie(it, changed) = scanNode(*it);
        any_changed |= changed;
      }
    }

    for (Node* n : block->nodes()) {
      for (Block* b : n->blocks()) {
        createFusionGroups(b);
      }
    }

    // Try to merge adjacent fusion groups together. Because we have only
    // merged by looking at graph inputs, without this we would not attempt to
    // merge adjacent fusion groups that don't have a dependency on each other.

    std::vector<Node*> initial_fusion_groups;
    for (Node* n : block->nodes()) {
      if (n->kind() == prim::TensorExprGroup) {
        initial_fusion_groups.push_back(n);
      }
    }

    Node* prev_fusion_group =
        !initial_fusion_groups.empty() ? initial_fusion_groups[0] : nullptr;

    for (const auto i : c10::irange(1, initial_fusion_groups.size())) {
      // Try merging the just created fusion group into the previous one.
      // If it did not work, then put the previous fusion group into
      // fusion_groups vector - we will not touch it anymore in this loop.
      // If merging succeeded, save the merged group as the "previous" fusion
      // group so that we can try to merge the next one into it.

      Node* fusion_group = initial_fusion_groups[i];
      debugDumpFusionGroup(
          "Trying to merge into the previous fusion group: ",
          prev_fusion_group);
      if (auto merged_fusion_group =
              tryMerge(prev_fusion_group, fusion_group)) {
        prev_fusion_group = *merged_fusion_group;
        debugDumpFusionGroup(
            "Successfully merged into the previous fusion group: ",
            prev_fusion_group);
      } else {
        GRAPH_DEBUG("Cannot merge into the previous fusion group");
        prev_fusion_group = fusion_group;
      }
    }
  }

  size_t blockSize(Block* block) {
    size_t num = 0;
    for (Node* n : block->nodes()) {
      // Don't count prim::Constants and prim::ListConstructs as these are nodes
      // we only pull in along with another, "main", node. E.g. the
      // ListConstruct nodes would also be pulled into a fusion group if they
      // are inputs of an aten::cat node.
      if (n->kind() == prim::Constant || n->kind() == prim::ListConstruct) {
        continue;
      }
      for (Block* b : n->blocks()) {
        num += blockSize(b);
      }
      num++;
    }
    return num;
  }

  bool hasConv(Block* block) {
    for (Node* n : block->nodes()) {
      if (n->kind() == aten::conv2d) {
        return true;
      }
    }
    return false;
  }

  bool inlineIfTooSmall(Node* n) {
    if (n->kind() != prim::TensorExprGroup) {
      return false;
    }
    auto subgraph = SubgraphUtils::getSubgraph(n);
    size_t num_nodes = blockSize(subgraph->block());
    // Allow small subgraphs containing conv2d, since we'll only select those
    // in cases where the tensorexpr implementation is faster than the aten
    // implementation.
    if (num_nodes < min_group_size_ && !hasConv(subgraph->block())) {
      GRAPH_UPDATE("Fusion group is too small, unmerging: ", *n);
      SubgraphUtils::unmergeSubgraph(n);
      return true;
    }
    // Cleanup the subgraph from duplicated constants while we're at it.
    ConstantPooling(subgraph);

    if (GRAPH_DEBUG_ENABLED) {
      GRAPH_EXPORT("", subgraph);
    }
    return false;
  }

  void inlineSmallFusionGroups(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
      Node* n = *it;
      it++;

      for (Block* b : n->blocks()) {
        inlineSmallFusionGroups(b);
      }
      inlineIfTooSmall(n);
    }
  }

  c10::optional<Node*> tryMerge(Node* fusion_group, Node* to_merge) {
    if (!canMerge(fusion_group, to_merge)) {
      return c10::nullopt;
    }

    std::vector<Node*> nodes_to_merge = {to_merge};

    if (to_merge->kind() == aten::cat) {
      Node* listconstruct = to_merge->input(0)->node();
      nodes_to_merge.push_back(listconstruct);
    }

    // First, try to move all the nodes we want to fuse next to the fusion
    // group.
    Node* move_point = fusion_group;
    for (auto n : nodes_to_merge) {
      GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n));
      if (!aliasDb_->moveBeforeTopologicallyValid(n, move_point)) {
        GRAPH_UPDATE("Failed to move because of AliasDB checks!");
        return c10::nullopt;
      }
      move_point = n;
    }

    // Now all the nodes that we're going to fuse are moved next to the fusion
    // group, so we can safely merge them into the fusion group subgraph.
    fusion_group = getOrCreateTensorExprSubgraph(fusion_group);

    for (auto n : nodes_to_merge) {
      GRAPH_UPDATE("Merging ", getHeader(n));
      SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
          n, fusion_group, *aliasDb_);
    }
    return fusion_group;
  }

  bool shapeIsKnown(Value* v) {
    if (v->type()->cast<TensorType>()) {
      if (!v->isCompleteTensor()) {
        return false;
      }
    }
    return true;
  }

  bool allShapesAreKnown(Node* node) {
    // TODO: Relax the checks to support dynamic shapes
    for (Value* input : node->inputs()) {
      if (!shapeIsKnown(input)) {
        return false;
      }
      if (input->node()->kind() == prim::ListConstruct) {
        if (!allShapesAreKnown(input->node())) {
          return false;
        }
      }
    }
    for (Value* output : node->outputs()) {
      if (!shapeIsKnown(output)) {
        return false;
      }
    }
    return true;
  }

  bool canFuseOnDevice(Value* v) {
    auto type = v->type()->cast<TensorType>();
    if (!type) {
      return true;
    }
    auto device = type->device();
    if (!device) {
      return false;
    }
    if (device->is_cpu()) {
      return canFuseOnCPU();
    } else if (device->is_cuda()) {
#ifndef C10_MOBILE
      if (fuser::cuda::isEnabled()) {
        return false;
      }
#endif
      return canFuseOnGPU();
    } else if (device->is_xpu()) {
      return false;
    }
    return false;
  }

  bool isFusableOnDevice(Node* node) {
    for (const auto& input : node->inputs()) {
      if (input->node()->kind() == prim::ListConstruct) {
        if (!isFusableOnDevice(input->node())) {
          return false;
        }
      }
      if (!canFuseOnDevice(input)) {
        return false;
      }
    }
    return true;
  }

  bool typesAreSupported(Node* node) {
    // clang-format off
    // (without this, clang-format would break up the schema strings so they
    // are no longer discoverable with ctrl-F)
    static const OperatorSet float_only_operator_set{
        "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor",
        "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor",
        "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor",
        "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor",
    };
    static const OperatorSet int_only_operator_set{
        "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor",
        "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor",
        "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor",
        "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor",
    };
    static const OperatorSet cpu_compute_heavy_set{
        "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
        "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
        "aten::matmul(Tensor self, Tensor other) -> Tensor",
    };
    static const OperatorSet gpu_only_operator_set{
        // On CPU, these are slower and less accurate than ATen kernels, because
        // ATen is able to use MKL-VML, whereas the fuser currently can't. The
        // fuser uses sleef instead because sleef provides functions that operate
        // on vectors, instead of large buffers.
        "aten::erf(Tensor self) -> Tensor",
        "aten::erfc(Tensor self) -> Tensor",
    };
    static const OperatorSet pow{
        "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor",
    };
    // clang-format on

    // Check types of input values.
    for (const Value* v : node->inputs()) {
      if (auto const& tt = v->type()->cast<TensorType>()) {
        auto const& st = tt->scalarType();
        auto const& device = tt->device();

        // All tensors must be typed.
        if (!st || !device) {
          return false;
        }

        // Byte tensors introduce too many corner cases in type promotion.
        // Better not to try to handle them.
        if (*st == c10::ScalarType::Byte) {
          return false;
        }

        // Float16 support has some issues (see e.g. #61336 and #61382), so for
        // now it's disabled. There seem to be some problems in HalfRewriter,
        // but on top of that Float16 has a few kinks on LLVM. Thus, on CPU we
        // additionally disable it until we either move to a more stable version
        // or find workarounds.
        if (*st == c10::ScalarType::Half && *device == c10::kCPU) {
          return false;
        }

        if (*st == c10::ScalarType::BFloat16 && *device == c10::kCPU) {
#ifndef TORCH_ENABLE_LLVM
          return false;
#endif
        }

        // These operators only support floats, because integer divisors need to
        // raise ZeroDivisionError.
        if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
          return false;
        }

        // These operators have complicated casting rules for floats.
        if (node->isMemberOf(int_only_operator_set) && isFloatingType(*st)) {
          return false;
        }
      } else if (node->isMemberOf(float_only_operator_set)) {
        // Check scalar operands of float-only ops.
        if (!v->type()->cast<FloatType>()) {
          return false;
        }
      } else if (node->isMemberOf(int_only_operator_set)) {
        if (!v->type()->cast<IntType>()) {
          return false;
        }
      }
    }

    // aten::pow has special rules to avoid complicated integer cases. We
    // expect the first arg to be a floating point tensor, and if that's the
    // case the type of the scalar exponent doesn't matter.
    if (node->isMemberOf(pow)) {
      auto const& tt = node->input(0)->type()->cast<TensorType>();
      if (!tt) {
        return false;
      }
      auto const& st = tt->scalarType();
      if (!st || !isFloatingType(*st)) {
        return false;
      }
    }

    // Operator is only supported on CPU.
    if (node->isMemberOf(cpu_compute_heavy_set)) {
      if (fuse_to_dynamic_shapes_) {
        return false;
      }

      auto device = tensorexpr::pickDeviceType(node->inputs());
      if (!device) {
        device = tensorexpr::pickDeviceType(node->outputs());
      }
      if (!device || !device->is_cpu()) {
        return false;
      }
    }

    // Operator is only supported on GPU.
    if (node->isMemberOf(gpu_only_operator_set)) {
      auto device = tensorexpr::pickDeviceType(node->inputs());
      if (!device) {
        device = tensorexpr::pickDeviceType(node->outputs());
      }
      if (!device || !device->is_cuda()) {
        return false;
      }
    }

    if (node->kind() == aten::to) {
      // only support same-device conversion
      auto device = tensorexpr::pickDeviceType(node->inputs());
      auto output_device = tensorexpr::pickDeviceType(node->outputs());
      if (!device || !output_device || *device != *output_device) {
        return false;
      }
      // The non_blocking arg only applies to cross-device conversion, which we
      // bail on above. The copy arg only matters if the op is a no-op, which we
      // don't start a fusion group from. The memory format is handled
      // separately in the NNC output.

      // all non-Tensor arguments must be constant
      for (size_t i = 1; i < node->inputs().size(); i++) {
        if (node->inputs().at(i)->node()->kind() != prim::Constant) {
          return false;
        }
      }

      if (has_unsupported_pin_memory(node)) {
        return false;
      }
    }

    if (node->kind() == aten::_autocast_to_reduced_precision ||
        node->kind() == aten::_autocast_to_full_precision) {
      for (auto i : c10::irange(1, node->inputs().size())) {
        if (node->inputs().at(i)->node()->kind() != prim::Constant) {
          return false;
        }
      }

      bool is_reduced_precision =
          node->kind() == aten::_autocast_to_reduced_precision;
      bool is_full_precision =
          node->kind() == aten::_autocast_to_full_precision;
      auto self_tensor = node->inputs()[0]; // input tensor

      if (auto const& tt = self_tensor->type()->cast<TensorType>()) {
        auto st = tt->scalarType();
        if (!st.has_value()) {
          return false;
        }

        auto device = tt->device();
        if (!device.has_value()) {
          return false;
        }

        bool is_cpu = device->is_cpu();

        if (*st != at::kFloat && is_reduced_precision && is_cpu) {
          // On CPU, the reduced-precision cast is a no-op unless the input is
          // float, and for such a no-op aten is faster than NNC, so don't pull
          // this node into the fusion group.
          return false;
        }

        if (*st != at::kBFloat16 && is_full_precision && is_cpu) {
          // On CPU, the full-precision cast is a no-op unless the input is
          // BFloat16, and for such a no-op aten is faster than NNC, so don't
          // pull this node into the fusion group.
          return false;
        }
      }

      if (has_unsupported_pin_memory(node)) {
        return false;
      }
    }

    if (node->kind() == aten::unsqueeze) {
      // `dim` argument must be a constant.
      if (node->input(1)->node()->kind() != prim::Constant) {
        return false;
      }
    }

    if (node->kind() == aten::_convolution && !tensorexpr::isConv2d(node)) {
      GRAPH_DEBUG("This aten::_convolution node is not a 2D conv");
      return false;
    }
    if (node->kind() == aten::_convolution || node->kind() == aten::conv2d) {
      if (!tensorexpr::conv2dIsSupportedJit(node) &&
          !tensorexpr::mkldnnPrepackedConvIsSupportedJit(node)) {
        GRAPH_DEBUG("Params of conv2d are not supported");
        return false;
      }
    }
    if (node->kind() == aten::matmul) {
      if (!tensorexpr::matmulIsSupported(node)) {
        GRAPH_DEBUG("Shapes of matmul inputs are not supported");
        return false;
      }
    }
    return true;
  }

#define REQ(cond)                           \
  if (!(cond)) {                            \
    GRAPH_DEBUG("Failed cond " #cond "\n"); \
    return false;                           \
  }

  bool canHandle(Node* node) {
    REQ(allShapesAreKnown(node));
    REQ(isFusableOnDevice(node));
    REQ(operators_not_to_fuse.find(node->kind()) ==
        operators_not_to_fuse.end());

    for (Value* input : node->inputs()) {
      if (auto const& tt = input->type()->cast<TensorType>()) {
        auto st = tt->scalarType();
        if (!st) {
          // All tensor types should be known.
          return false;
        }
        if (c10::isComplexType(*st) || c10::isQIntType(*st)) {
          return false;
        }
      }
    }
    if (node->kind() == aten::cat) {
      REQ(node->input(0)->node()->kind() == prim::ListConstruct);
      REQ(node->input(0)->uses().size() == 1);
      REQ(node->input(1)->node()->kind() == prim::Constant);
      auto const& listconstruct = node->input(0)->node();
      REQ(tensorexpr::pickDeviceType(listconstruct->inputs()));
    } else {
      REQ(tensorexpr::pickDeviceType(node->inputs()));
    }

    // Only fuse aten::batch_norm when the parameter 'training' is false
    if (node->kind() == aten::batch_norm) {
      REQ(node->input(5)->node()->kind() == prim::Constant);
      REQ(!toIValue(node->input(5)).value().toBool());
    }

    REQ(tensorexpr::isSupported(node));
    REQ(typesAreSupported(node));

    // A hook into the optimization limiter to allow bisecting the pass.
    REQ(JIT_OPT_ALLOWED);

    if (fuse_to_dynamic_shapes_) {
      // Allow only if the node has a shape function defined.
      // ListConstruct node is an exception since that is needed to fuse
      // aten::cat, though it does not have a shape function.
      REQ(node->kind() == prim::ListConstruct ||
          node->kind() == prim::TensorExprGroup ||
          node->isMemberOf(tensorexpr::getCustomOperatorSet()) ||
          (node->maybeSchema() && shapeComputeGraphForSchema(node->schema())));
    }

    return true;
  }

  bool canMerge(Node* consumer, Node* producer) {
    // Only fuse within a block
    REQ(consumer->owningBlock() == producer->owningBlock());

    // Symbolic checks
    REQ(canHandle(producer) || producer->kind() == prim::TensorExprGroup);
    TORCH_INTERNAL_ASSERT(
        consumer->kind() == prim::TensorExprGroup || canHandle(consumer));

    // nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
    // The specific limit is a function of constant memory size, amount
    // available to pass arguments, and some implementation dependence. Select a
    // safe limit here.
    constexpr size_t subgraphArgLimit = 128;
    auto const nInputs = consumer->inputs().size() +
        consumer->outputs().size() + producer->inputs().size() +
        producer->outputs().size();
    REQ(nInputs <= subgraphArgLimit);

    // Device checks
    if (consumer->kind() != aten::cat && producer->kind() != aten::cat) {
      // aten::cat needs special handling because it takes a Tensor[] as its
      // input. We deal with that in the code below.
      auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs());
      REQ(consumer_device);
      auto producer_device = tensorexpr::pickDeviceType(producer->inputs());
      REQ(producer_device);
      REQ(*consumer_device == *producer_device);
    }

    // Alias checks
    REQ(aliasDb_->couldMoveBeforeTopologically(producer, consumer));

    // Ops that return aliases can only be folded if this is the only use.
    if (producer->kind() == aten::slice ||
        producer->kind() == aten::unsqueeze ||
        producer->kind() == prim::ConstantChunk) {
      for (auto& use : producer->output(0)->uses()) {
        REQ(use.user == consumer);
      }
    }

    if (!consumer->hasAttribute(attr::Subgraph) &&
        consumer->kind() != prim::TensorExprGroup) {
      // Don't initiate a fusion group from prim::ListConstruct
      REQ(consumer->kind() != prim::ListConstruct);
      REQ(consumer->kind() != aten::slice);
      REQ(consumer->kind() != aten::unsqueeze);
      REQ(consumer->kind() != prim::ConstantChunk);

      // Don't initiate a fusion group just for a constant operand
      REQ(producer->kind() != prim::Constant);
    }

    if (producer->kind() == aten::cat) {
      REQ(producer->input(0)->node()->kind() == prim::ListConstruct);
      REQ(producer->input(0)->uses().size() == 1);
      REQ(producer->input(1)->node()->kind() == prim::Constant);
      auto const& listConstruct = producer->input(0)->node();
      // We're merging listconstruct->cat->consumer. cat is the producer here
      // and we cannot determine its device type - we should use device of the
      // listconstruct instead
      auto listconstruct_device =
          tensorexpr::pickDeviceType(listConstruct->inputs());
      auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs());
      REQ(listconstruct_device);
      REQ(consumer_device);
      REQ(*listconstruct_device == *consumer_device);
      for (auto const& input : listConstruct->inputs()) {
        REQ(isFusableOnDevice(input->node()));
      }
      REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit);
    } else if (consumer->kind() == aten::cat) {
      REQ(consumer->input(0)->node()->kind() == prim::ListConstruct);
      REQ(consumer->input(0)->uses().size() == 1);
      REQ(consumer->input(1)->node()->kind() == prim::Constant);
      auto const& listConstruct = consumer->input(0)->node();
      // We're merging listconstruct->cat. cat is the consumer and listconstruct
      // is the producer. cat doesn't have its own device type, so the only
      // thing we should check is that listconstruct's device is well defined
      // (i.e. all its inputs have the same device).
      auto listconstruct_device =
          tensorexpr::pickDeviceType(listConstruct->inputs());
      REQ(listconstruct_device);
      REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit);
    } else {
      REQ(isFusableOnDevice(producer));
    }

    return true;
  }
#undef REQ

  void prepareFusionGroupAndGuardOutputs(Block* block) {
    std::vector<Node*> fusion_groups;
    for (Node* n : block->nodes()) {
      for (Block* b : n->blocks()) {
        prepareFusionGroupAndGuardOutputs(b);
      }
      if (n->kind() == prim::TensorExprGroup) {
        fusion_groups.push_back(n);
      }
    }
    for (Node* fusion_group : fusion_groups) {
      removeOutputsUsedOnlyInSize(fusion_group);
      insertTypeGuard(
          fusion_group,
          [](const TensorTypePtr& t) { return t; },
          prim::TypeCheck);
    }
  }

  void generalizeFusionGroups(Block* block) {
    std::vector<Node*> fusion_groups;
    for (Node* n : block->nodes()) {
      for (Block* b : n->blocks()) {
        generalizeFusionGroups(b);
      }
      if (n->kind() == prim::TensorExprGroup) {
        fusion_groups.push_back(n);
      }
    }
    for (Node* fusion_group : fusion_groups) {
      removeOutputsUsedOnlyInSize(fusion_group);
      VLOG(1) << "GenerateGuard for fusion group: " << *fusion_group;
      if (!GenerateGuard(fusion_group, add_composed_op_)) {
        VLOG(1) << " Unfusing the fusion group because GenerateGuard failed"
                << std::endl;
        SubgraphUtils::unmergeSubgraph(fusion_group);
      }
    }
  }

  // This function parses the option provided by the environment variable
  // "PYTORCH_TENSOREXPR_DONT_FUSE".
  // This variable allows users to disable fusion on a list of specified
  // operators that are separated by ':'. e.g.,
  // 'PYTORCH_TENSOREXPR_DONT_FUSE="clamp:mul:add"' disables fusion on
  // aten::clamp, aten::mul and aten::add.
  void parseTENotFuseOption() {
    const char* option = std::getenv("PYTORCH_TENSOREXPR_DONT_FUSE");
    std::stringstream in_ss;
    if (option) {
      in_ss << option;
    }

    std::string line;
    while (std::getline(in_ss, line, ':')) {
      if (line.empty()) {
        continue;
      }
      operators_not_to_fuse.insert(c10::Symbol::aten(line));
    }
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  std::set<NodeKind> operators_not_to_fuse;
  // Minimal size of a fusion group
  size_t min_group_size_;
  // compose Runtime Type Guard and Kernel in one op
  bool add_composed_op_;
  // generalize static shapes to dynamic shapes
  bool fuse_to_dynamic_shapes_;
};

void FuseTensorExprs(
    std::shared_ptr<Graph>& graph,
    size_t min_group_size,
    bool add_composed_op,
    bool fuse_to_dynamic_shapes) {
  GRAPH_DUMP("Before TExprFuser: ", graph);

  // Temporary change for Block code generation.
  if (tensorexpr::getTEGenerateBlockCode()) {
    min_group_size = 1;
  }

  if (add_composed_op) {
    TORCH_INTERNAL_ASSERT(
        fuse_to_dynamic_shapes, "Fusing static shapes with composed op NYI");
  }

  // Get rid of dead code so that we don't waste effort fusing it.
  EliminateDeadCode(graph);

  TensorExprFuser fuser(
      graph, min_group_size, add_composed_op, fuse_to_dynamic_shapes);
  fuser.run();

  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);

  GRAPH_DUMP("After TExprFuser: ", graph);
}

Operation createTensorExprOp(const Node* node) {
  bool dynamic_shape_fusion_node =
      node->hasAttribute(attr::striding_inputs_desc);
  if (!dynamic_shape_fusion_node) {
    auto kernel =
        std::make_shared<tensorexpr::TensorExprKernel>(node->g(attr::Subgraph));
    return [kernel](Stack& stack) {
      RECORD_FUNCTION(kernel->getKernelName(), std::vector<c10::IValue>());
      kernel->run(stack);
      return 0;
    };
  }

  // Handle the case when dynamic shape fusion is enabled.
  VLOG(1) << "Compiling a new kernel for " << *node;
  std::vector<int64_t> sym_shapes;
  if (node->hasAttribute(attr::symbolic_shape_inputs)) {
    sym_shapes = node->is(attr::symbolic_shape_inputs);
  }
  bool allow_stack_outputs = false;
  if (node->hasAttribute(attr::allow_stack_outputs)) {
    allow_stack_outputs = node->i(attr::allow_stack_outputs) == 1;
  }

  std::unordered_map<c10::Symbol, tensorexpr::NNCLoweringFunction>
      custom_lowerings;
  auto subgraph = node->g(attr::Subgraph);
  IValue sym_strides = node->ival(attr::striding_inputs_desc);

  // The striding descriptor is serialized on the node as a vector of vectors
  // of strings; translate it back to the StrideInput enum.
  std::vector<std::vector<std::string>> sym_strides_strs =
      sym_strides.to<std::vector<std::vector<std::string>>>();
  std::vector<std::vector<StrideInput>> striding_inputs;
  for (const auto& vec : sym_strides_strs) {
    std::vector<StrideInput> input_desc;
    input_desc.reserve(vec.size());
    for (const std::string& str : vec) {
      input_desc.push_back(strideInputFromString(str));
    }
    striding_inputs.push_back(input_desc);
  }
  std::unordered_map<const Value*, std::vector<StrideInput>> stride_map;
  size_t index = 0;
  for (Value* v : subgraph->inputs()) {
    if (!v->type()->cast<TensorType>()) {
      continue;
    }
    stride_map[v] = striding_inputs[index];
    index++;
  }
  std::vector<std::string> output_desc =
      node->ival(attr::striding_outputs_desc).to<std::vector<std::string>>();
  for (size_t i = 0; i < subgraph->outputs().size(); ++i) {
    stride_map[subgraph->outputs().at(i)] = {
        strideInputFromString(output_desc.at(i))};
  }

  std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
      std::make_shared<tensorexpr::TensorExprKernel>(
          subgraph,
          custom_lowerings,
          sym_shapes,
          /*pre_alloc*/ false,
          stride_map);

  auto num_subgraph_inputs = subgraph->inputs().size();
  return [kernel, num_subgraph_inputs, allow_stack_outputs](Stack& stack) {
    RECORD_FUNCTION(kernel->getKernelName(), std::vector<c10::IValue>());

    // Stack contents:
    //   [<outputs>] <inputs>
    //
    // If the number of graph inputs is the same as the stack size, then no
    // outputs are being passed in. Otherwise, output tensors are passed in
    // at the bottom of the stack. So, we call the appropriate run function
    // in TensorExprKernel.
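    // e.g. with 2 preallocated outputs and 3 inputs the stack would be
    // [out0, out1, in0, in1, in2].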
    if (num_subgraph_inputs == stack.size() || !allow_stack_outputs) {
      kernel->run(stack);
    } else {
      kernel->runWithAllocatedOutputs(stack);
    }
    return 0;
  };
}

RegisterOperators TensorExprOps({
    torch::jit::Operator(
        prim::TensorExprGroup,
        createTensorExprOp,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});

} // namespace jit
} // namespace torch