1 | #include <torch/csrc/jit/passes/replacement_of_old_operators.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <caffe2/serialize/versions.h> |
5 | #include <torch/csrc/jit/frontend/schema_matching.h> |
6 | #include <torch/csrc/jit/ir/irparser.h> |
7 | #include <torch/csrc/jit/operator_upgraders/upgraders.h> |
8 | #include <torch/csrc/jit/operator_upgraders/utils.h> |
9 | #include <torch/csrc/jit/operator_upgraders/version_map.h> |
10 | #include <torch/csrc/jit/runtime/graph_iterator.h> |
11 | #include <limits> |
12 | #include <string> |
13 | #include <unordered_map> |
14 | #include <utility> |
15 | |
16 | namespace torch { |
17 | namespace jit { |
18 | |
19 | struct OldOpsReplacerWithUpgraders { |
20 | OldOpsReplacerWithUpgraders(std::shared_ptr<Graph> graph) |
21 | : graph_(std::move(graph)) {} |
22 | |
23 | void run() { |
24 | if (!graph_->get_op_version().has_value()) { |
25 | return; |
26 | } |
27 | |
28 | auto current_version = graph_->get_op_version().value(); |
29 | DepthFirstGraphNodeIterator graph_it(graph_); |
30 | Node* node = graph_it.next(); |
31 | while (node) { |
32 | // load the schema name for this op |
33 | c10::optional<std::string> schema_name = c10::nullopt; |
34 | if (auto op_schema = node->maybeSchema()) { |
35 | schema_name = getFullSchemaName(*op_schema); |
36 | } else { |
37 | schema_name = node->getHistoricSchemaName(); |
38 | } |
39 | |
40 | if (schema_name.has_value()) { |
41 | // this implies there was a version bump because of this operator |
42 | auto version_entry = |
43 | get_operator_version_map().find(schema_name.value()); |
44 | if (version_entry != get_operator_version_map().end()) { |
45 | const auto& entry = version_entry->second; |
46 | auto upgrader_entry = findUpgrader(entry, current_version); |
47 | if (!upgrader_entry.has_value()) { |
48 | if (!isOpSymbolCurrent(schema_name.value(), current_version)) { |
49 | TORCH_INTERNAL_ASSERT( |
50 | false, |
51 | "Upgrader must be present for " , |
52 | schema_name.value(), |
53 | ". The upgrader might have deprecated" ); |
54 | } |
55 | node = graph_it.next(); |
56 | continue; |
57 | } |
58 | auto upgrader_entry_val = upgrader_entry.value(); |
59 | auto upgrader_name = upgrader_entry_val.upgrader_name; |
60 | auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name); |
61 | TORCH_INTERNAL_ASSERT( |
62 | upgrader_graph_entry != dump_upgraders_map().end(), |
63 | "Corresponding upgrader graph for " , |
64 | upgrader_name, |
65 | " must exist." , |
66 | " This upgrader" |
67 | " might be deprecated." ); |
68 | |
69 | auto upgrader_graph = upgrader_graph_entry->second; |
70 | // inline the upgrader function body |
71 | WithInsertPoint guard(node); |
72 | auto new_outputs = insertGraph( |
73 | *node->owningGraph(), *upgrader_graph, node->inputs()); |
74 | const auto& old_outputs = node->outputs(); |
75 | TORCH_INTERNAL_ASSERT(new_outputs.size() == old_outputs.size()); |
76 | for (const auto i : c10::irange(old_outputs.size())) { |
77 | TORCH_INTERNAL_ASSERT( |
78 | new_outputs[i]->type() == old_outputs[i]->type()) |
79 | old_outputs[i]->replaceAllUsesWith(new_outputs[i]); |
80 | } |
81 | node->removeAllInputs(); |
82 | node->destroy(); |
83 | } |
84 | } |
85 | node = graph_it.next(); |
86 | } |
87 | |
88 | // now that we updated the graph, we want to bump the |
89 | // graph version too. |
90 | graph_->set_op_version(caffe2::serialize::kProducedFileFormatVersion); |
91 | } |
92 | |
93 | std::shared_ptr<Graph> graph_; |
94 | }; |
95 | |
96 | TORCH_API void ReplaceOldOperatorsWithUpgraders(std::shared_ptr<Graph> graph) { |
97 | OldOpsReplacerWithUpgraders(std::move(graph)).run(); |
98 | } |
99 | |
100 | } // namespace jit |
101 | } // namespace torch |
102 | |