1 | #include <stack> |
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <torch/csrc/jit/api/module.h> |
5 | #include <torch/csrc/jit/passes/constant_pooling.h> |
6 | #include <torch/csrc/jit/passes/constant_propagation.h> |
7 | #include <torch/csrc/jit/passes/quantization/helper.h> |
8 | #include <torch/csrc/jit/passes/quantization/register_packed_params.h> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
13 | namespace { |
14 | bool isPrepackNode(Node* n) { |
15 | return ( |
16 | n->kind() == Symbol::fromQualString("quantized::linear_prepack" ) || |
17 | n->kind() == Symbol::fromQualString("quantized::conv1d_prepack" ) || |
18 | n->kind() == Symbol::fromQualString("quantized::conv2d_prepack" ) || |
19 | n->kind() == Symbol::fromQualString("quantized::conv3d_prepack" ) || |
20 | n->kind() == |
21 | Symbol::fromQualString("quantized::conv_transpose1d_prepack" ) || |
22 | n->kind() == |
23 | Symbol::fromQualString("quantized::conv_transpose2d_prepack" )); |
24 | } |
25 | |
26 | std::pair<Value*, std::string> findFPWeight(Node* prepack_node) { |
27 | TORCH_CHECK(isPrepackNode(prepack_node)); |
28 | Node* n = nullptr; |
29 | n = prepack_node->input(0)->node(); |
30 | bool is_quantize_node = |
31 | (n->kind() == Symbol::fromQualString("aten::quantize_per_tensor" ) || |
32 | n->kind() == Symbol::fromQualString("aten::quantize_per_channel" )); |
33 | TORCH_CHECK( |
34 | is_quantize_node, |
35 | "Input to prepack node must be output of weight quantization." ); |
36 | // First input of quantize node is FP32 weight |
37 | n = n->input(0)->node(); |
38 | bool is_getattr_node = (n->kind() == prim::GetAttr); |
39 | if (is_getattr_node) { |
40 | return {n->input(0), n->s(attr::name)}; |
41 | } |
42 | return {nullptr, "AttributeDoesNotExist" }; |
43 | } |
44 | } // namespace |
45 | |
// Concatenates `paths` into a single dotted string. Note that every segment —
// including the last — is followed by a '.', so {"a", "b"} yields "a.b.".
// Callers appear to rely on this exact format, so it is preserved.
std::string joinPaths(const std::vector<std::string>& paths) {
  std::string joined;
  for (const auto& segment : paths) {
    joined += segment;
    joined += '.';
  }
  return joined;
}
// Must run this pass after constant folding.
//
// For every node in `method_name`'s graph accepted by `is_packed_param`
// (a prepack op such as quantized::linear_prepack), this pass:
//   1. registers a fresh, value-less module attribute named
//      "<attr_prefix>_<method_name>_ondevice_ptq_packed_weight_<uid>",
//   2. inserts a prim::SetAttr right after the prepack op so that running the
//      method stores the packed params into that attribute,
//   3. inserts a prim::GetAttr for the same attribute and rewires the single
//      consumer of the prepack output to read it from the attribute instead,
//   4. overwrites the original FP32 weight attribute with an empty tensor so
//      the module no longer carries the unpacked weight.
// Returns the set of generated packed-param attribute names plus the dotted
// module-access paths of the original weights (see joinPaths; paths end with
// a trailing '.'), so a later pass can strip the SetAttr nodes/attributes.
std::unordered_set<std::string> RegisterPrePackParams(
    Module& m,
    const std::string& method_name,
    const PrePackParamFilterFn& is_packed_param,
    const std::string& attr_prefix) {
  int64_t uid = 0; // int + method name gives unique identifier
  auto graph = m.get_method(method_name).graph();
  std::stack<Block*> blocks_to_visit;
  // NOTE(review): nodes_to_delete is never populated or consumed in this
  // function — it looks like dead state. Confirm before removing.
  std::unordered_set<Node*> nodes_to_delete;
  blocks_to_visit.push(graph->block());
  std::string attr_name_base =
      attr_prefix + "_" + method_name + "_ondevice_ptq_packed_weight_";
  std::unordered_set<std::string> packed_param_names;

  // Depth-first walk over the graph's blocks (sub-blocks pushed below).
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      if (is_packed_param(n)) {
        // New nodes created below (e.g. insertGetAttr) go right after `n`.
        WithInsertPoint ins(n->next());
        Value* packed_param_value = n->output(0);
        TORCH_CHECK(n->outputs().size() == 1, "Prepack ops have single output");
        auto attr_name = attr_name_base + c10::to_string(uid++);
        // Exactly one use is required because only that one use is rewired to
        // the GetAttr below; the SetAttr inserted here also becomes a user.
        TORCH_CHECK(
            packed_param_value->uses().size() == 1,
            "Packed param must be used by exactly one op.");
        auto use = packed_param_value->uses()[0];
        // Bump uid until the name is free on the module.
        // NOTE(review): this collision path inserts an extra "_" compared to
        // the initial name above — presumably unintentional, but harmless
        // since names only need to be unique. Confirm before changing.
        while (m.hasattr(attr_name)) {
          attr_name = attr_name_base + "_" + c10::to_string(uid++);
        }
        // Now register attribute for this packed param but dont set it to any
        // value. No value because we dont know what the value is at this point.
        // Only when we run on-device ptq workflow, e.g. run quantize_forward
        // method, is when the linear_prepack op will be executed and at that
        // point we will have the actual value for this attribute.
        m.register_attribute(attr_name, n->output(0)->type(), IValue());
        // In order to add the output of linear_prepack, we now have to do
        // setAttr Thus when quantize_forward is actually called the attribute
        // is appropriately set.
        Node* set_attr = graph->createSetAttr(
            graph->inputs()[0], attr_name, packed_param_value);
        set_attr->insertAfter(n);
        // Now let's add GetAttr for the same attribute.
        // Why?
        // Because eventually the method being modified will be cloned into
        // quantize_forward and quantized_forward.
        // quantize_forward will only have, for example, linear_prepack and
        // SetAttr Thus when quantize_forward is run attributes on the module
        // are set. Then in quantized_forward we will actually get
        // packed_params, via GetAttr and supply it to, for example,
        // dynamic_linear At the end quantize_forward will not have any ops like
        // dynamic_linear and quantized_forward will not have any linear_prepack
        // or SetAttr
        Value* packed_param_attr =
            graph->insertGetAttr(graph->inputs()[0], attr_name)
                ->setType(n->output(0)->type());
        // We must replace this specific usage and we cannot doe
        // replaceAllUsesWith This is because we first had to insert SetAttr
        // node. This also takes as input packed_param_value, similar to the
        // actual op. But only the use of the actual op must be replaced by
        // output of GetAttr. Input of SetAttr still must use the
        // packed_param_value
        use.user->replaceInput(use.offset, packed_param_attr);
        // Record the name of the attribute so that we can delete the SetAttr
        // for it
        packed_param_names.insert(std::move(attr_name));

        // Now make sure that original weight is reset such that the module
        // does not have weight attribute set anymore
        auto value_weight_names_pair = findFPWeight(n);
        Value* v = value_weight_names_pair.first;
        std::string weight_name = std::move(value_weight_names_pair.second);
        // Empty placeholder tensor that replaces the FP32 weight.
        auto empty_tensor =
            at::empty({0}, at::TensorOptions().requires_grad(false));
        Node* none_node = graph->create(prim::Constant);
        none_node->t_(attr::value, empty_tensor);
        // none_node->output()->setType(TensorType::create(at::kFloat,
        // c10::kCPU, 1, false));
        Node* set_attr_orig_weight =
            graph->createSetAttr(v, weight_name, none_node->output());
        // Constant must dominate the SetAttr that consumes it, so insert the
        // SetAttr after the GetAttr above and the constant just before it.
        set_attr_orig_weight->insertAfter(packed_param_attr->node());
        none_node->insertBefore(set_attr_orig_weight);
        auto* self = v->owningGraph()->inputs()[0];
        // Record the dotted path of the blanked weight alongside the
        // generated attribute names in the returned set.
        std::vector<std::string> path = getModuleAccessPath(v, self);
        packed_param_names.emplace(joinPaths(path));
      }
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  return packed_param_names;
}
147 | |
148 | } // namespace jit |
149 | } // namespace torch |
150 | |