#include <torch/csrc/jit/passes/replacement_of_old_operators.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <torch/csrc/jit/operator_upgraders/utils.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
15
16namespace torch {
17namespace jit {
18
19struct OldOpsReplacerWithUpgraders {
20 OldOpsReplacerWithUpgraders(std::shared_ptr<Graph> graph)
21 : graph_(std::move(graph)) {}
22
23 void run() {
24 if (!graph_->get_op_version().has_value()) {
25 return;
26 }
27
28 auto current_version = graph_->get_op_version().value();
29 DepthFirstGraphNodeIterator graph_it(graph_);
30 Node* node = graph_it.next();
31 while (node) {
32 // load the schema name for this op
33 c10::optional<std::string> schema_name = c10::nullopt;
34 if (auto op_schema = node->maybeSchema()) {
35 schema_name = getFullSchemaName(*op_schema);
36 } else {
37 schema_name = node->getHistoricSchemaName();
38 }
39
40 if (schema_name.has_value()) {
41 // this implies there was a version bump because of this operator
42 auto version_entry =
43 get_operator_version_map().find(schema_name.value());
44 if (version_entry != get_operator_version_map().end()) {
45 const auto& entry = version_entry->second;
46 auto upgrader_entry = findUpgrader(entry, current_version);
47 if (!upgrader_entry.has_value()) {
48 if (!isOpSymbolCurrent(schema_name.value(), current_version)) {
49 TORCH_INTERNAL_ASSERT(
50 false,
51 "Upgrader must be present for ",
52 schema_name.value(),
53 ". The upgrader might have deprecated");
54 }
55 node = graph_it.next();
56 continue;
57 }
58 auto upgrader_entry_val = upgrader_entry.value();
59 auto upgrader_name = upgrader_entry_val.upgrader_name;
60 auto upgrader_graph_entry = dump_upgraders_map().find(upgrader_name);
61 TORCH_INTERNAL_ASSERT(
62 upgrader_graph_entry != dump_upgraders_map().end(),
63 "Corresponding upgrader graph for ",
64 upgrader_name,
65 " must exist.",
66 " This upgrader"
67 " might be deprecated.");
68
69 auto upgrader_graph = upgrader_graph_entry->second;
70 // inline the upgrader function body
71 WithInsertPoint guard(node);
72 auto new_outputs = insertGraph(
73 *node->owningGraph(), *upgrader_graph, node->inputs());
74 const auto& old_outputs = node->outputs();
75 TORCH_INTERNAL_ASSERT(new_outputs.size() == old_outputs.size());
76 for (const auto i : c10::irange(old_outputs.size())) {
77 TORCH_INTERNAL_ASSERT(
78 new_outputs[i]->type() == old_outputs[i]->type())
79 old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
80 }
81 node->removeAllInputs();
82 node->destroy();
83 }
84 }
85 node = graph_it.next();
86 }
87
88 // now that we updated the graph, we want to bump the
89 // graph version too.
90 graph_->set_op_version(caffe2::serialize::kProducedFileFormatVersion);
91 }
92
93 std::shared_ptr<Graph> graph_;
94};
95
96TORCH_API void ReplaceOldOperatorsWithUpgraders(std::shared_ptr<Graph> graph) {
97 OldOpsReplacerWithUpgraders(std::move(graph)).run();
98}
99
100} // namespace jit
101} // namespace torch
102