1#include <stack>
2
3#include <torch/csrc/jit/api/module.h>
4#include <torch/csrc/jit/jit_log.h>
5#include <torch/csrc/jit/passes/constant_pooling.h>
6#include <torch/csrc/jit/passes/constant_propagation.h>
7#include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
8#include <torch/csrc/jit/passes/quantization/helper.h>
9
10namespace torch {
11namespace jit {
12
13// Hoists packed params from a conv module to the parent module.
14// The benefit is that after this hoisting, the conv module
15// no longer holds anything and can be deleted, reducing model
16// size.
17//
18// Before (easy case):
19//
20// %1 = prim::GetAttr[name="conv1"][%self]
21// %2 = prim::GetAttr[name="_packed_params][%1]
22//
23// After (easy case):
24//
25// %2 = prim::GetAttr[name="{prefix}.conv1._packed_params"][%self]
26//
27// Before (generic case):
28//
29// %1 = prim::GetAttr[name="name1"][%self]
30// %2 = prim::GetAttr[name="name2"][%1]
31// ...
32// %n = prim::GetAttr[name="_packed_params][%n-1]
33//
34// After (generic case):
35//
36// %n =
37// prim::GetAttr[name="{prefix}.name1{...}.name(n-1)._packed_params"][%self]
38//
39void hoistConvPackedParams(
40 Module& rootModule,
41 Node* getConvPackedParamsNode,
42 const std::string& prefix,
43 int& nameUniqueCounter) {
44 auto method = rootModule.get_method("forward");
45 auto graph = method.graph();
46 Value* rootModuleAsValue = graph->inputs()[0];
47
48 // get a path from root module to conv module
49 Value* convModuleAsValue = getConvPackedParamsNode->inputs()[0];
50 std::vector<std::string> rootToConvPath =
51 getModuleAccessPath(convModuleAsValue, rootModuleAsValue);
52
53 // get a module object representing the conv
54 Module convModule = findChildModule(rootModule, rootToConvPath);
55
56 // get the packed params value
57 c10::IValue packedParams = convModule.attr("_packed_params");
58
59 // create the new name
60
61 std::string suffix = "";
62 for (const auto& attrName : rootToConvPath) {
63 suffix += attrName + ".";
64 }
65 std::string newNameBase = prefix + "." + suffix + "_packed_params";
66 nameUniqueCounter++;
67 std::string newName = newNameBase + "." + c10::to_string(nameUniqueCounter);
68 while (rootModule.hasattr(newName)) {
69 nameUniqueCounter++;
70 newName = newNameBase + "." + c10::to_string(nameUniqueCounter);
71 }
72
73 // copy the packed params
74 rootModule.register_attribute(newName, packedParams.type(), packedParams);
75
76 // change target module to rootModule
77 getConvPackedParamsNode->replaceInput(0, rootModuleAsValue);
78
79 // change attribute name to new name
80 getConvPackedParamsNode->s_(Symbol::attr("name"), newName);
81}
82
83void HoistConvPackedParams(script::Module& m) {
84 auto method = m.get_method("forward");
85 auto graph = method.graph();
86
87 std::stack<Block*> blocks_to_visit;
88 blocks_to_visit.push(graph->block());
89 std::string attr_name_base = "_jit_pass_hoist_conv_packed_params";
90 // counter to ensure new attribute names are unique
91 int nameUniqueCounter = 0;
92
93 while (!blocks_to_visit.empty()) {
94 Block* b = blocks_to_visit.top();
95 blocks_to_visit.pop();
96
97 for (Node* n : b->nodes()) {
98 // make sure this node is fetching {foo}.{_packed_params}
99 bool isGetPackedParamsNode =
100 n->kind() == prim::GetAttr && n->s(attr::name) == "_packed_params";
101 if (isGetPackedParamsNode) {
102 // make sure the foo in {foo}.{_packed_params} is a quantized conv
103 c10::optional<std::string> moduleName = getModuleName(n->inputs()[0]);
104 bool moduleNameIsQuantizedConv = moduleName.has_value() &&
105 (moduleName.value() ==
106 "__torch__.torch.ao.nn.quantized.modules.conv.Conv1d" ||
107 moduleName.value() ==
108 "__torch__.torch.ao.nn.quantized.modules.conv.Conv2d" ||
109 moduleName.value() ==
110 "__torch__.torch.ao.nn.quantized.modules.conv.Conv3d" ||
111 moduleName.value() ==
112 "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU1d" ||
113 moduleName.value() ==
114 "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d" ||
115 moduleName.value() ==
116 "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d" ||
117 // BC Stuff
118 moduleName.value() ==
119 "__torch__.torch.nn.quantized.modules.conv.Conv1d" ||
120 moduleName.value() ==
121 "__torch__.torch.nn.quantized.modules.conv.Conv2d" ||
122 moduleName.value() ==
123 "__torch__.torch.nn.quantized.modules.conv.Conv3d");
124
125 if (moduleNameIsQuantizedConv) {
126 GRAPH_UPDATE("Hoisting ", *n, " to root module.");
127 hoistConvPackedParams(m, n, attr_name_base, nameUniqueCounter);
128 }
129 }
130
131 for (Block* subblock : n->blocks()) {
132 blocks_to_visit.push(subblock);
133 }
134
135 } // for
136
137 } // while
138}
139
140} // namespace jit
141} // namespace torch
142