1 | #include <torch/csrc/jit/passes/eliminate_no_ops.h> |
---|---|
2 | |
3 | #include <torch/csrc/jit/jit_log.h> |
4 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
5 | #include <torch/csrc/jit/runtime/graph_iterator.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | namespace { |
11 | |
12 | bool allInputsAreTensors(Node* node) { |
13 | for (const auto* value : node->inputs()) { |
14 | const auto& type = value->type(); |
15 | if (!type->castRaw<TensorType>()) { |
16 | return false; |
17 | } |
18 | } |
19 | return true; |
20 | } |
21 | |
22 | bool cannotOptimize(Node* node) { |
23 | const auto kind = node->kind(); |
24 | if (kind == aten::__is__ || kind == aten::__isnot__) { |
25 | return allInputsAreTensors(node); |
26 | } |
27 | return false; |
28 | } |
29 | |
30 | // Certain ops can make this optimization unsound. For example, |
31 | // consider the following graph: |
32 | // %y : Tensor = aten::detach(%x) |
33 | // %b : bool = aten::__is__(%y, %x) (= False) |
34 | // After remove detach, we would get |
35 | // %b : bool = aten::__is__(%x, %x) (= True!) |
36 | bool containsInvalidOp(std::shared_ptr<Graph>& graph) { |
37 | for (auto* node : graph->nodes()) { |
38 | if (cannotOptimize(node)) { |
39 | return true; |
40 | } |
41 | } |
42 | return false; |
43 | } |
44 | |
45 | } // namespace |
46 | |
47 | bool EliminateNoOps( |
48 | std::shared_ptr<Graph>& graph, |
49 | std::unordered_set<c10::Symbol> custom_ops) { |
50 | GRAPH_DUMP("Before EliminateNoOps: ", graph); |
51 | if (containsInvalidOp(graph)) { |
52 | return false; |
53 | } |
54 | // Ops here should be of the form x = f(x, ...) |
55 | std::unordered_set<c10::Symbol> no_ops{aten::detach}; |
56 | no_ops.insert(custom_ops.begin(), custom_ops.end()); |
57 | |
58 | bool changed = false; |
59 | |
60 | auto graph_it = DepthFirstGraphNodeIterator(graph); |
61 | for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) { |
62 | auto it = no_ops.find(node->kind()); |
63 | if (it == no_ops.end()) { |
64 | continue; |
65 | } |
66 | |
67 | changed = true; |
68 | node->output()->replaceAllUsesWith(node->input(0)); |
69 | } |
70 | |
71 | if (changed) { |
72 | EliminateDeadCode(graph); |
73 | } |
74 | |
75 | GRAPH_DUMP("After EliminateNoOps: ", graph); |
76 | return changed; |
77 | } |
78 | |
79 | } // namespace jit |
80 | } // namespace torch |
81 |