1 | #include <stack> |
2 | |
3 | #include <torch/csrc/jit/api/module.h> |
4 | #include <torch/csrc/jit/jit_log.h> |
5 | #include <torch/csrc/jit/passes/constant_pooling.h> |
6 | #include <torch/csrc/jit/passes/constant_propagation.h> |
7 | #include <torch/csrc/jit/passes/hoist_conv_packed_params.h> |
8 | #include <torch/csrc/jit/passes/quantization/helper.h> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
13 | // Hoists packed params from a conv module to the parent module. |
14 | // The benefit is that after this hoisting, the conv module |
15 | // no longer holds anything and can be deleted, reducing model |
16 | // size. |
17 | // |
18 | // Before (easy case): |
19 | // |
20 | // %1 = prim::GetAttr[name="conv1"][%self] |
// %2 = prim::GetAttr[name="_packed_params"][%1]
22 | // |
23 | // After (easy case): |
24 | // |
25 | // %2 = prim::GetAttr[name="{prefix}.conv1._packed_params"][%self] |
26 | // |
27 | // Before (generic case): |
28 | // |
29 | // %1 = prim::GetAttr[name="name1"][%self] |
30 | // %2 = prim::GetAttr[name="name2"][%1] |
31 | // ... |
// %n = prim::GetAttr[name="_packed_params"][%n-1]
33 | // |
34 | // After (generic case): |
35 | // |
36 | // %n = |
37 | // prim::GetAttr[name="{prefix}.name1{...}.name(n-1)._packed_params"][%self] |
38 | // |
39 | void hoistConvPackedParams( |
40 | Module& rootModule, |
41 | Node* getConvPackedParamsNode, |
42 | const std::string& prefix, |
43 | int& nameUniqueCounter) { |
44 | auto method = rootModule.get_method("forward" ); |
45 | auto graph = method.graph(); |
46 | Value* rootModuleAsValue = graph->inputs()[0]; |
47 | |
48 | // get a path from root module to conv module |
49 | Value* convModuleAsValue = getConvPackedParamsNode->inputs()[0]; |
50 | std::vector<std::string> rootToConvPath = |
51 | getModuleAccessPath(convModuleAsValue, rootModuleAsValue); |
52 | |
53 | // get a module object representing the conv |
54 | Module convModule = findChildModule(rootModule, rootToConvPath); |
55 | |
56 | // get the packed params value |
57 | c10::IValue packedParams = convModule.attr("_packed_params" ); |
58 | |
59 | // create the new name |
60 | |
61 | std::string suffix = "" ; |
62 | for (const auto& attrName : rootToConvPath) { |
63 | suffix += attrName + "." ; |
64 | } |
65 | std::string newNameBase = prefix + "." + suffix + "_packed_params" ; |
66 | nameUniqueCounter++; |
67 | std::string newName = newNameBase + "." + c10::to_string(nameUniqueCounter); |
68 | while (rootModule.hasattr(newName)) { |
69 | nameUniqueCounter++; |
70 | newName = newNameBase + "." + c10::to_string(nameUniqueCounter); |
71 | } |
72 | |
73 | // copy the packed params |
74 | rootModule.register_attribute(newName, packedParams.type(), packedParams); |
75 | |
76 | // change target module to rootModule |
77 | getConvPackedParamsNode->replaceInput(0, rootModuleAsValue); |
78 | |
79 | // change attribute name to new name |
80 | getConvPackedParamsNode->s_(Symbol::attr("name" ), newName); |
81 | } |
82 | |
83 | void HoistConvPackedParams(script::Module& m) { |
84 | auto method = m.get_method("forward" ); |
85 | auto graph = method.graph(); |
86 | |
87 | std::stack<Block*> blocks_to_visit; |
88 | blocks_to_visit.push(graph->block()); |
89 | std::string attr_name_base = "_jit_pass_hoist_conv_packed_params" ; |
90 | // counter to ensure new attribute names are unique |
91 | int nameUniqueCounter = 0; |
92 | |
93 | while (!blocks_to_visit.empty()) { |
94 | Block* b = blocks_to_visit.top(); |
95 | blocks_to_visit.pop(); |
96 | |
97 | for (Node* n : b->nodes()) { |
98 | // make sure this node is fetching {foo}.{_packed_params} |
99 | bool isGetPackedParamsNode = |
100 | n->kind() == prim::GetAttr && n->s(attr::name) == "_packed_params" ; |
101 | if (isGetPackedParamsNode) { |
102 | // make sure the foo in {foo}.{_packed_params} is a quantized conv |
103 | c10::optional<std::string> moduleName = getModuleName(n->inputs()[0]); |
104 | bool moduleNameIsQuantizedConv = moduleName.has_value() && |
105 | (moduleName.value() == |
106 | "__torch__.torch.ao.nn.quantized.modules.conv.Conv1d" || |
107 | moduleName.value() == |
108 | "__torch__.torch.ao.nn.quantized.modules.conv.Conv2d" || |
109 | moduleName.value() == |
110 | "__torch__.torch.ao.nn.quantized.modules.conv.Conv3d" || |
111 | moduleName.value() == |
112 | "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU1d" || |
113 | moduleName.value() == |
114 | "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d" || |
115 | moduleName.value() == |
116 | "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d" || |
117 | // BC Stuff |
118 | moduleName.value() == |
119 | "__torch__.torch.nn.quantized.modules.conv.Conv1d" || |
120 | moduleName.value() == |
121 | "__torch__.torch.nn.quantized.modules.conv.Conv2d" || |
122 | moduleName.value() == |
123 | "__torch__.torch.nn.quantized.modules.conv.Conv3d" ); |
124 | |
125 | if (moduleNameIsQuantizedConv) { |
126 | GRAPH_UPDATE("Hoisting " , *n, " to root module." ); |
127 | hoistConvPackedParams(m, n, attr_name_base, nameUniqueCounter); |
128 | } |
129 | } |
130 | |
131 | for (Block* subblock : n->blocks()) { |
132 | blocks_to_visit.push(subblock); |
133 | } |
134 | |
135 | } // for |
136 | |
137 | } // while |
138 | } |
139 | |
140 | } // namespace jit |
141 | } // namespace torch |
142 | |