1#include <torch/csrc/jit/passes/quantization/dedup_module_uses.h>
2
3#include <torch/csrc/jit/jit_log.h>
4#include <torch/csrc/jit/passes/quantization/helper.h>
5
6#include <stack>
7
8namespace torch {
9namespace jit {
10namespace {
11class ModuleUseDeduper {
12 public:
13 ModuleUseDeduper(Module& module) : module_(module) {}
14 void dedup() {
15 for (auto& method : module_.get_methods()) {
16 const auto& graph = method.graph();
17 findModuleUses(graph.get());
18 }
19 dedupModuleUses();
20 }
21
22 private:
23 // Analyze the code to record information represents
24 // uses of the module, which we'll use later to actually perform the dedup
25 // operation Please see the comments of member variables of the class for more
26 // information
27 void findModuleUses(Graph* graph) {
28 GRAPH_DUMP("Finding module uses for ", graph);
29
30 std::stack<Block*> blocks_to_visit;
31 blocks_to_visit.push(graph->block());
32 Value* self = graph->inputs()[0];
33 while (!blocks_to_visit.empty()) {
34 Block* b = blocks_to_visit.top();
35 blocks_to_visit.pop();
36 for (Node* n : b->nodes()) {
37 for (Block* subblock : n->blocks()) {
38 blocks_to_visit.push(subblock);
39 }
40 if (n->kind() != prim::CallMethod) {
41 continue;
42 }
43 Value* instance = n->inputs()[0];
44 // boundary_val is the value we get when we trace back
45 // the GetAttr access chain until we hit the input of graph
46 // or a node that is not prim::GetAttr
47 auto path = getModuleAccessPath(instance, self);
48
49 // path.size() == 0 means we're calling a method
50 // on self, we don't need to dedup uses of self
51 if (path.empty()) {
52 continue;
53 }
54 value_to_path_map_[instance] = path;
55 auto m = findChildModule(module_, path);
56 // If we fail to insert the module to the unique_modules_ set,
57 // which means there are uses of this module before this point,
58 // we'll have to rewrite the use
59 if (!unique_modules_.insert(m._ivalue()).second) {
60 uses_to_rewrite_.push_back(instance);
61 GRAPH_DEBUG("Found use to rewrite: ", instance->debugName());
62 }
63 }
64 }
65 }
66
67 // Deduplicate module uses given the information we recorded before
68 void dedupModuleUses() {
69 for (Value* v : uses_to_rewrite_) {
70 const auto& path = value_to_path_map_.at(v);
71 const auto& m = findChildModule(module_, path);
72 // add a clone of the child module to the parent of the duplicated module
73 const auto& child_name = addChildModule(module_, m, path);
74 TORCH_INTERNAL_ASSERT(v->node()->kind() == prim::GetAttr);
75 // change the name in GetAttr call
76 auto original_name = v->node()->s(attr::name);
77 v->node()->s_(attr::name, child_name);
78 GRAPH_UPDATE(
79 "Module use dedup: changing use of original module ",
80 original_name,
81 " to ",
82 child_name);
83 }
84 }
85
86 std::string addChildModule(
87 Module& module,
88 const Module& child_module,
89 const std::vector<std::string>& path) {
90 TORCH_INTERNAL_ASSERT(
91 !path.empty(), "path must have at least one element.");
92 // Parent module of the leaf child module corresponding to
93 // the path
94 auto parent_of_leaf = findChildModule(
95 module, std::vector<std::string>(path.begin(), path.end() - 1));
96
97 // Original name of the child module
98 const std::string& original_name = path[path.size() - 1];
99 int uid = 0;
100 std::string child_name = original_name + "_" + c10::to_string(uid++);
101 while (parent_of_leaf.hasattr(child_name)) {
102 child_name = original_name + "_" + c10::to_string(uid++);
103 }
104 parent_of_leaf.register_module(child_name, child_module.deepcopy());
105 return child_name;
106 }
107
108 Module module_;
109 // Map from value of module instance to the list of names of submodules
110 // starting from the top level module, e.g. ["sub1", "sub2", "relu"]
111 // Also this is a cache of calling `getModuleAccessPath` of the value
112 std::unordered_map<Value*, std::vector<std::string>> value_to_path_map_;
113 // Set of unique modules that are used in the graphs
114 std::unordered_set<ModulePtr> unique_modules_;
115 // Values that represent the module instance(the use of the module)
116 // that we'll need to rewrite as a use of a cloned module
117 // instance
118 std::vector<Value*> uses_to_rewrite_;
119};
120
121} // namespace
122
123void DedupModuleUses(Module& module) {
124 ModuleUseDeduper d(module);
125 d.dedup();
126}
127
128} // namespace jit
129} // namespace torch
130