#include <torch/csrc/jit/passes/eliminate_no_ops.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

namespace torch {
namespace jit {

namespace {

bool allInputsAreTensors(Node* node) {
  for (const auto* value : node->inputs()) {
    const auto& type = value->type();
    if (!type->castRaw<TensorType>()) {
      return false;
    }
  }
  return true;
}

bool cannotOptimize(Node* node) {
  const auto kind = node->kind();
  if (kind == aten::__is__ || kind == aten::__isnot__) {
    return allInputsAreTensors(node);
  }
  return false;
}

// Certain ops can make this optimization unsound. For example,
// consider the following graph:
//   %y : Tensor = aten::detach(%x)
//   %b : bool = aten::__is__(%y, %x) (= False)
// After removing the detach, we would get
//   %b : bool = aten::__is__(%x, %x) (= True!)
bool containsInvalidOp(std::shared_ptr<Graph>& graph) {
  for (auto* node : graph->nodes()) {
    if (cannotOptimize(node)) {
      return true;
    }
  }
  return false;
}

} // namespace

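// Replaces the outputs of no-op nodes (aten::detach by default, plus any
// kinds passed in `custom_ops`) with their first input, then runs dead code
// elimination. Returns true if the graph was modified.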
bool EliminateNoOps(
    std::shared_ptr<Graph>& graph,
    std::unordered_set<c10::Symbol> custom_ops) {
  GRAPH_DUMP("Before EliminateNoOps: ", graph);
  if (containsInvalidOp(graph)) {
    return false;
  }
  // Ops here should be of the form x = f(x, ...)
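  // e.g. %y = aten::detach(%x): every use of %y can be rewritten to use %x.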
  std::unordered_set<c10::Symbol> no_ops{aten::detach};
  no_ops.insert(custom_ops.begin(), custom_ops.end());

  bool changed = false;

  auto graph_it = DepthFirstGraphNodeIterator(graph);
  for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    auto it = no_ops.find(node->kind());
    if (it == no_ops.end()) {
      continue;
    }

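    // Forward the node's first input to all uses of its output; the now-dead
    // node itself is removed by EliminateDeadCode below.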
    changed = true;
    node->output()->replaceAllUsesWith(node->input(0));
  }

  if (changed) {
    EliminateDeadCode(graph);
  }

  GRAPH_DUMP("After EliminateNoOps: ", graph);
  return changed;
}

} // namespace jit
} // namespace torch