1 | #include <torch/csrc/jit/passes/quantization/dedup_module_uses.h> |
2 | |
3 | #include <torch/csrc/jit/jit_log.h> |
4 | #include <torch/csrc/jit/passes/quantization/helper.h> |
5 | |
6 | #include <stack> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace { |
11 | class ModuleUseDeduper { |
12 | public: |
13 | ModuleUseDeduper(Module& module) : module_(module) {} |
14 | void dedup() { |
15 | for (auto& method : module_.get_methods()) { |
16 | const auto& graph = method.graph(); |
17 | findModuleUses(graph.get()); |
18 | } |
19 | dedupModuleUses(); |
20 | } |
21 | |
22 | private: |
23 | // Analyze the code to record information represents |
24 | // uses of the module, which we'll use later to actually perform the dedup |
25 | // operation Please see the comments of member variables of the class for more |
26 | // information |
27 | void findModuleUses(Graph* graph) { |
28 | GRAPH_DUMP("Finding module uses for " , graph); |
29 | |
30 | std::stack<Block*> blocks_to_visit; |
31 | blocks_to_visit.push(graph->block()); |
32 | Value* self = graph->inputs()[0]; |
33 | while (!blocks_to_visit.empty()) { |
34 | Block* b = blocks_to_visit.top(); |
35 | blocks_to_visit.pop(); |
36 | for (Node* n : b->nodes()) { |
37 | for (Block* subblock : n->blocks()) { |
38 | blocks_to_visit.push(subblock); |
39 | } |
40 | if (n->kind() != prim::CallMethod) { |
41 | continue; |
42 | } |
43 | Value* instance = n->inputs()[0]; |
44 | // boundary_val is the value we get when we trace back |
45 | // the GetAttr access chain until we hit the input of graph |
46 | // or a node that is not prim::GetAttr |
47 | auto path = getModuleAccessPath(instance, self); |
48 | |
49 | // path.size() == 0 means we're calling a method |
50 | // on self, we don't need to dedup uses of self |
51 | if (path.empty()) { |
52 | continue; |
53 | } |
54 | value_to_path_map_[instance] = path; |
55 | auto m = findChildModule(module_, path); |
56 | // If we fail to insert the module to the unique_modules_ set, |
57 | // which means there are uses of this module before this point, |
58 | // we'll have to rewrite the use |
59 | if (!unique_modules_.insert(m._ivalue()).second) { |
60 | uses_to_rewrite_.push_back(instance); |
61 | GRAPH_DEBUG("Found use to rewrite: " , instance->debugName()); |
62 | } |
63 | } |
64 | } |
65 | } |
66 | |
67 | // Deduplicate module uses given the information we recorded before |
68 | void dedupModuleUses() { |
69 | for (Value* v : uses_to_rewrite_) { |
70 | const auto& path = value_to_path_map_.at(v); |
71 | const auto& m = findChildModule(module_, path); |
72 | // add a clone of the child module to the parent of the duplicated module |
73 | const auto& child_name = addChildModule(module_, m, path); |
74 | TORCH_INTERNAL_ASSERT(v->node()->kind() == prim::GetAttr); |
75 | // change the name in GetAttr call |
76 | auto original_name = v->node()->s(attr::name); |
77 | v->node()->s_(attr::name, child_name); |
78 | GRAPH_UPDATE( |
79 | "Module use dedup: changing use of original module " , |
80 | original_name, |
81 | " to " , |
82 | child_name); |
83 | } |
84 | } |
85 | |
86 | std::string addChildModule( |
87 | Module& module, |
88 | const Module& child_module, |
89 | const std::vector<std::string>& path) { |
90 | TORCH_INTERNAL_ASSERT( |
91 | !path.empty(), "path must have at least one element." ); |
92 | // Parent module of the leaf child module corresponding to |
93 | // the path |
94 | auto parent_of_leaf = findChildModule( |
95 | module, std::vector<std::string>(path.begin(), path.end() - 1)); |
96 | |
97 | // Original name of the child module |
98 | const std::string& original_name = path[path.size() - 1]; |
99 | int uid = 0; |
100 | std::string child_name = original_name + "_" + c10::to_string(uid++); |
101 | while (parent_of_leaf.hasattr(child_name)) { |
102 | child_name = original_name + "_" + c10::to_string(uid++); |
103 | } |
104 | parent_of_leaf.register_module(child_name, child_module.deepcopy()); |
105 | return child_name; |
106 | } |
107 | |
108 | Module module_; |
109 | // Map from value of module instance to the list of names of submodules |
110 | // starting from the top level module, e.g. ["sub1", "sub2", "relu"] |
111 | // Also this is a cache of calling `getModuleAccessPath` of the value |
112 | std::unordered_map<Value*, std::vector<std::string>> value_to_path_map_; |
113 | // Set of unique modules that are used in the graphs |
114 | std::unordered_set<ModulePtr> unique_modules_; |
115 | // Values that represent the module instance(the use of the module) |
116 | // that we'll need to rewrite as a use of a cloned module |
117 | // instance |
118 | std::vector<Value*> uses_to_rewrite_; |
119 | }; |
120 | |
121 | } // namespace |
122 | |
123 | void DedupModuleUses(Module& module) { |
124 | ModuleUseDeduper d(module); |
125 | d.dedup(); |
126 | } |
127 | |
128 | } // namespace jit |
129 | } // namespace torch |
130 | |