#include <stack>

#include <ATen/ATen.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/passes/quantization/register_packed_params.h>

namespace torch {
namespace jit {

namespace {
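// Returns true if the node is one of the quantized prepack ops
// (linear / convNd / conv_transposeNd) whose packed output this pass
// registers as a module attribute.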
bool isPrepackNode(Node* n) {
  return (
      n->kind() == Symbol::fromQualString("quantized::linear_prepack") ||
      n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
      n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
      n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") ||
      n->kind() ==
          Symbol::fromQualString("quantized::conv_transpose1d_prepack") ||
      n->kind() ==
          Symbol::fromQualString("quantized::conv_transpose2d_prepack"));
}

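// Traces back from a prepack node, through the weight quantization op
// (aten::quantize_per_tensor / aten::quantize_per_channel), to the GetAttr
// that loads the FP32 weight. Returns the value of the module owning the
// attribute together with the attribute name, or
// {nullptr, "AttributeDoesNotExist"} if the weight does not come from a
// GetAttr.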
std::pair<Value*, std::string> findFPWeight(Node* prepack_node) {
  TORCH_CHECK(isPrepackNode(prepack_node));
  Node* n = nullptr;
  n = prepack_node->input(0)->node();
  bool is_quantize_node =
      (n->kind() == Symbol::fromQualString("aten::quantize_per_tensor") ||
       n->kind() == Symbol::fromQualString("aten::quantize_per_channel"));
  TORCH_CHECK(
      is_quantize_node,
      "Input to prepack node must be output of weight quantization.");
  // First input of quantize node is FP32 weight
  n = n->input(0)->node();
  bool is_getattr_node = (n->kind() == prim::GetAttr);
  if (is_getattr_node) {
    return {n->input(0), n->s(attr::name)};
  }
  return {nullptr, "AttributeDoesNotExist"};
}
} // namespace

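// Joins the segments of a module access path (e.g. {"encoder", "linear1"})
// into a single dotted string; note the result keeps a trailing ".".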
std::string joinPaths(const std::vector<std::string>& paths) {
  std::string path;
  for (const auto& p : paths) {
    path.append(p).append(".");
  }
  return path;
}

// Must run this pass after constant folding.
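// For every node accepted by `is_packed_param`, this pass:
//   1. registers a new module attribute named
//      "<attr_prefix>_<method_name>_ondevice_ptq_packed_weight_<uid>",
//   2. inserts a SetAttr that stores the packed param into that attribute
//      and a GetAttr that reads it back for the consuming op,
//   3. resets the original FP32 weight attribute to an empty tensor.
// Returns the registered attribute names along with the dotted access paths
// of the modules owning the original weights.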
std::unordered_set<std::string> RegisterPrePackParams(
    Module& m,
    const std::string& method_name,
    const PrePackParamFilterFn& is_packed_param,
    const std::string& attr_prefix) {
  int64_t uid = 0; // int + method name gives unique identifier
  auto graph = m.get_method(method_name).graph();
  std::stack<Block*> blocks_to_visit;
  std::unordered_set<Node*> nodes_to_delete;
  blocks_to_visit.push(graph->block());
  std::string attr_name_base =
      attr_prefix + "_" + method_name + "_ondevice_ptq_packed_weight_";
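  // Collects the freshly registered packed-param attribute names, plus
  // (added at the end of the rewrite below) the dotted access paths of the
  // modules owning the original FP32 weights.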
  std::unordered_set<std::string> packed_param_names;

  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      if (is_packed_param(n)) {
        WithInsertPoint ins(n->next());
        Value* packed_param_value = n->output(0);
        TORCH_CHECK(n->outputs().size() == 1, "Prepack ops have single output");
        auto attr_name = attr_name_base + c10::to_string(uid++);
        TORCH_CHECK(
            packed_param_value->uses().size() == 1,
            "Packed param must be used by exactly one op.");
        auto use = packed_param_value->uses()[0];
        while (m.hasattr(attr_name)) {
          attr_name = attr_name_base + "_" + c10::to_string(uid++);
        }
        // Now register an attribute for this packed param, but don't set it
        // to any value yet. We don't know the value at this point: only when
        // the on-device PTQ workflow runs, e.g. when the quantize_forward
        // method is executed, will the linear_prepack op actually run and
        // produce the real value for this attribute.
        m.register_attribute(attr_name, n->output(0)->type(), IValue());
        // To store the output of linear_prepack we now insert a SetAttr, so
        // that when quantize_forward is actually called the attribute is set
        // appropriately.
        Node* set_attr = graph->createSetAttr(
            graph->inputs()[0], attr_name, packed_param_value);
        set_attr->insertAfter(n);
        // Now let's add a GetAttr for the same attribute.
        // Why?
        // Because eventually the method being modified will be cloned into
        // quantize_forward and quantized_forward.
        // quantize_forward will only have, for example, linear_prepack and
        // SetAttr. Thus when quantize_forward is run, the attributes on the
        // module are set. Then in quantized_forward we actually fetch the
        // packed params via GetAttr and supply them to, for example,
        // dynamic_linear. In the end quantize_forward will not have any ops
        // like dynamic_linear, and quantized_forward will not have any
        // linear_prepack or SetAttr.
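        // A rough sketch of what the rewritten IR around a linear prepack
        // node ends up looking like (names are illustrative, not the exact
        // ones produced):
        //   %packed = quantized::linear_prepack(%qweight, %bias)
        //   prim::SetAttr[name="<attr_name>"](%self, %packed)
        //   %packed_attr = prim::GetAttr[name="<attr_name>"](%self)
        //   <consumer op now takes %packed_attr instead of %packed>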
        Value* packed_param_attr =
            graph->insertGetAttr(graph->inputs()[0], attr_name)
                ->setType(n->output(0)->type());
        // We must replace this specific use; we cannot use
        // replaceAllUsesWith. This is because we first had to insert the
        // SetAttr node, which also takes packed_param_value as input, just
        // like the actual op. Only the use by the actual op must be replaced
        // by the output of GetAttr; the input of SetAttr must keep using
        // packed_param_value.
        use.user->replaceInput(use.offset, packed_param_attr);
        // Record the name of the attribute so that we can delete the SetAttr
        // for it later.
        packed_param_names.insert(std::move(attr_name));

        // Now make sure that the original weight is reset so that the module
        // no longer holds the FP32 weight attribute.
        auto value_weight_names_pair = findFPWeight(n);
        Value* v = value_weight_names_pair.first;
        std::string weight_name = std::move(value_weight_names_pair.second);
        auto empty_tensor =
            at::empty({0}, at::TensorOptions().requires_grad(false));
        Node* none_node = graph->create(prim::Constant);
        none_node->t_(attr::value, empty_tensor);
        // none_node->output()->setType(TensorType::create(at::kFloat,
        // c10::kCPU, 1, false));
        Node* set_attr_orig_weight =
            graph->createSetAttr(v, weight_name, none_node->output());
        set_attr_orig_weight->insertAfter(packed_param_attr->node());
        none_node->insertBefore(set_attr_orig_weight);
        auto* self = v->owningGraph()->inputs()[0];
        std::vector<std::string> path = getModuleAccessPath(v, self);
        packed_param_names.emplace(joinPaths(path));
      }
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  return packed_param_names;
}

} // namespace jit
} // namespace torch