1 | |
---|---|
2 | #include <torch/csrc/jit/passes/check_strict_fusion.h> |
3 | |
4 | #include <c10/util/Exception.h> |
5 | #include <torch/csrc/jit/frontend/error_report.h> |
6 | #include <torch/csrc/jit/ir/ir.h> |
7 | #include <torch/csrc/jit/jit_log.h> |
8 | #include <torch/csrc/jit/passes/quantization/helper.h> |
9 | #include <torch/csrc/jit/runtime/graph_iterator.h> |
10 | #include <unordered_map> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | namespace { |
16 | |
17 | bool isStrictFusion(Value* value) { |
18 | const auto class_name = getModuleName(value); |
19 | return class_name.has_value() && |
20 | (*class_name == "__torch__.torch.jit.strict_fusion"); |
21 | } |
22 | |
23 | } // namespace |
24 | |
25 | bool fusionGuardCheck(Symbol k) { |
26 | return k == Symbol::prim("TensorExprDynamicGuard") || k == prim::TypeCheck || |
27 | k == prim::CudaFusionGuard || k == prim::RequiresGradCheck; |
28 | } |
29 | |
30 | std::unordered_set<Node*> collectValuesUsedInGuard( |
31 | Node* guarding_if, |
32 | Node* enter_node) { |
33 | // DFS to collect |
34 | std::unordered_set<Node*> visited_nodes; |
35 | std::vector<Node*> queue = {guarding_if}; |
36 | |
37 | while (!queue.empty()) { |
38 | Node* curr = queue[queue.size() - 1]; |
39 | queue.pop_back(); |
40 | visited_nodes.insert(curr); |
41 | // these nodes directly test Tensor inputs, and are not part of additional |
42 | // guards inserted |
43 | if (fusionGuardCheck(curr->kind())) { |
44 | continue; |
45 | } |
46 | for (Value* v : curr->inputs()) { |
47 | Node* inp_node = v->node(); |
48 | if (inp_node->isBefore(enter_node) || |
49 | inp_node->owningBlock() != enter_node->owningBlock()) { |
50 | continue; |
51 | } |
52 | if (visited_nodes.count(inp_node)) { |
53 | continue; |
54 | } |
55 | queue.push_back(inp_node); |
56 | } |
57 | } |
58 | return visited_nodes; |
59 | } |
60 | |
61 | void checkForUnfusedOps(Node* enter_node) { |
62 | std::vector<Node*> unsupported_nodes; |
63 | std::vector<Node*> guarding_ifs; // if multiple, we will throw |
64 | for (Node* node = enter_node->next(); node->kind() != prim::Exit; |
65 | node = node->next()) { |
66 | if (node->kind() == prim::If && |
67 | fusionGuardCheck(node->input()->node()->kind())) { |
68 | guarding_ifs.push_back(node); |
69 | continue; |
70 | } |
71 | unsupported_nodes.push_back(node); |
72 | } |
73 | |
74 | if (guarding_ifs.size() > 1) { |
75 | std::stringstream ss; |
76 | ss << "Found multiple fusions: \n"; |
77 | for (Node* n : guarding_ifs) { |
78 | ss << *n << "\n"; |
79 | } |
80 | throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str(); |
81 | } |
82 | |
83 | // NVFuser/autodiff/nnc all insert a number of guards, see |
84 | // `CudaFusionViewGuard Example Graph` |
85 | // to check for unfused nodes, look at node's whose outputs |
86 | // are not depended on by the fusion guard |
87 | // restrict search for all values after the first |
88 | // node in the prim::Enter block |
89 | |
90 | std::unordered_set<Node*> guarding_check_nodes; |
91 | if (guarding_ifs.size() == 1) { |
92 | guarding_check_nodes = |
93 | collectValuesUsedInGuard(guarding_ifs[0], enter_node); |
94 | } |
95 | std::vector<Node*> unfused_nodes_not_used_in_guard; |
96 | for (Node* unfused : unsupported_nodes) { |
97 | if (!guarding_check_nodes.count(unfused)) { |
98 | unfused_nodes_not_used_in_guard.push_back(unfused); |
99 | } |
100 | } |
101 | if (!unfused_nodes_not_used_in_guard.empty()) { |
102 | std::stringstream ss; |
103 | ss << "Found unfused operators: \n"; |
104 | for (Node* unfused : unfused_nodes_not_used_in_guard) { |
105 | ss << "\t"; |
106 | if (unfused->maybeSchema()) { |
107 | ss << unfused->schema(); |
108 | } else { |
109 | unfused->kind().toDisplayString(); |
110 | } |
111 | ss << "\n"; |
112 | } |
113 | auto range = enter_node->input()->node()->sourceRange(); |
114 | throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str(); |
115 | } |
116 | } |
117 | |
118 | void CheckStrictFusion(std::shared_ptr<Graph>& graph) { |
119 | DepthFirstGraphNodeIterator it(graph); |
120 | Node* n = nullptr; |
121 | while ((n = it.next()) != nullptr) { |
122 | if (n->kind() == prim::Enter && isStrictFusion(n->input())) { |
123 | checkForUnfusedOps(n); |
124 | } |
125 | } |
126 | |
127 | // TODO: remove context manager after checks |
128 | // TODO: improve control flow not taken, right now always errors |
129 | } |
130 | |
131 | } // namespace jit |
132 | } // namespace torch |
133 |