1
2#include <torch/csrc/jit/passes/check_strict_fusion.h>
3
4#include <c10/util/Exception.h>
5#include <torch/csrc/jit/frontend/error_report.h>
6#include <torch/csrc/jit/ir/ir.h>
7#include <torch/csrc/jit/jit_log.h>
8#include <torch/csrc/jit/passes/quantization/helper.h>
9#include <torch/csrc/jit/runtime/graph_iterator.h>
10#include <unordered_map>
11
12namespace torch {
13namespace jit {
14
15namespace {
16
17bool isStrictFusion(Value* value) {
18 const auto class_name = getModuleName(value);
19 return class_name.has_value() &&
20 (*class_name == "__torch__.torch.jit.strict_fusion");
21}
22
23} // namespace
24
25bool fusionGuardCheck(Symbol k) {
26 return k == Symbol::prim("TensorExprDynamicGuard") || k == prim::TypeCheck ||
27 k == prim::CudaFusionGuard || k == prim::RequiresGradCheck;
28}
29
30std::unordered_set<Node*> collectValuesUsedInGuard(
31 Node* guarding_if,
32 Node* enter_node) {
33 // DFS to collect
34 std::unordered_set<Node*> visited_nodes;
35 std::vector<Node*> queue = {guarding_if};
36
37 while (!queue.empty()) {
38 Node* curr = queue[queue.size() - 1];
39 queue.pop_back();
40 visited_nodes.insert(curr);
41 // these nodes directly test Tensor inputs, and are not part of additional
42 // guards inserted
43 if (fusionGuardCheck(curr->kind())) {
44 continue;
45 }
46 for (Value* v : curr->inputs()) {
47 Node* inp_node = v->node();
48 if (inp_node->isBefore(enter_node) ||
49 inp_node->owningBlock() != enter_node->owningBlock()) {
50 continue;
51 }
52 if (visited_nodes.count(inp_node)) {
53 continue;
54 }
55 queue.push_back(inp_node);
56 }
57 }
58 return visited_nodes;
59}
60
61void checkForUnfusedOps(Node* enter_node) {
62 std::vector<Node*> unsupported_nodes;
63 std::vector<Node*> guarding_ifs; // if multiple, we will throw
64 for (Node* node = enter_node->next(); node->kind() != prim::Exit;
65 node = node->next()) {
66 if (node->kind() == prim::If &&
67 fusionGuardCheck(node->input()->node()->kind())) {
68 guarding_ifs.push_back(node);
69 continue;
70 }
71 unsupported_nodes.push_back(node);
72 }
73
74 if (guarding_ifs.size() > 1) {
75 std::stringstream ss;
76 ss << "Found multiple fusions: \n";
77 for (Node* n : guarding_ifs) {
78 ss << *n << "\n";
79 }
80 throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str();
81 }
82
83 // NVFuser/autodiff/nnc all insert a number of guards, see
84 // `CudaFusionViewGuard Example Graph`
85 // to check for unfused nodes, look at node's whose outputs
86 // are not depended on by the fusion guard
87 // restrict search for all values after the first
88 // node in the prim::Enter block
89
90 std::unordered_set<Node*> guarding_check_nodes;
91 if (guarding_ifs.size() == 1) {
92 guarding_check_nodes =
93 collectValuesUsedInGuard(guarding_ifs[0], enter_node);
94 }
95 std::vector<Node*> unfused_nodes_not_used_in_guard;
96 for (Node* unfused : unsupported_nodes) {
97 if (!guarding_check_nodes.count(unfused)) {
98 unfused_nodes_not_used_in_guard.push_back(unfused);
99 }
100 }
101 if (!unfused_nodes_not_used_in_guard.empty()) {
102 std::stringstream ss;
103 ss << "Found unfused operators: \n";
104 for (Node* unfused : unfused_nodes_not_used_in_guard) {
105 ss << "\t";
106 if (unfused->maybeSchema()) {
107 ss << unfused->schema();
108 } else {
109 unfused->kind().toDisplayString();
110 }
111 ss << "\n";
112 }
113 auto range = enter_node->input()->node()->sourceRange();
114 throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str();
115 }
116}
117
118void CheckStrictFusion(std::shared_ptr<Graph>& graph) {
119 DepthFirstGraphNodeIterator it(graph);
120 Node* n = nullptr;
121 while ((n = it.next()) != nullptr) {
122 if (n->kind() == prim::Enter && isStrictFusion(n->input())) {
123 checkForUnfusedOps(n);
124 }
125 }
126
127 // TODO: remove context manager after checks
128 // TODO: improve control flow not taken, right now always errors
129}
130
131} // namespace jit
132} // namespace torch
133