1 | #include <stack> |
2 | |
3 | #include <torch/csrc/jit/api/module.h> |
4 | #include <torch/csrc/jit/passes/constant_pooling.h> |
5 | #include <torch/csrc/jit/passes/constant_propagation.h> |
6 | #include <torch/csrc/jit/passes/prepack_folding.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | // Must run this pass after constant folding. |
12 | void PrePackingOpsFolder( |
13 | script::Module& m, |
14 | const PrePackingOpsFilterFn& is_foldable_op, |
15 | const std::string& attr_prefix) { |
16 | for (auto& method : m.get_methods()) { |
17 | int64_t uid = 0; // int + method name gives unique identifier |
18 | auto graph = method.graph(); |
19 | std::stack<Block*> blocks_to_visit; |
20 | std::unordered_set<Node*> nodes_to_delete; |
21 | blocks_to_visit.push(graph->block()); |
22 | std::string attr_name_base = |
23 | attr_prefix + "_" + method.name() + "._jit_pass_packed_weight_" ; |
24 | while (!blocks_to_visit.empty()) { |
25 | Block* b = blocks_to_visit.top(); |
26 | blocks_to_visit.pop(); |
27 | for (Node* n : b->nodes()) { |
28 | if (is_foldable_op(n)) { |
29 | auto optional_outputs = runNodeIfInputsAreConstant(n); |
30 | if (optional_outputs) { |
31 | auto outputs = optional_outputs.value(); |
32 | TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output" ); |
33 | auto attr_name = attr_name_base + c10::to_string(uid++); |
34 | TORCH_CHECK( |
35 | !(m.type()->findAttributeSlot(attr_name)), |
36 | "Attribute name " , |
37 | attr_name, |
38 | " already exist in" , |
39 | " module of type:" , |
40 | m.type()->name()->qualifiedName(), |
41 | ". Please make sure that" , |
42 | " FoldPrePackingOps is run at the top level module only." ); |
43 | m.register_attribute(attr_name, n->output(0)->type(), outputs[0]); |
44 | Value* prepack_op_value = n->output(0); |
45 | WithInsertPoint ins(prepack_op_value->node()); |
46 | Value* packed_weight_attr = |
47 | graph->insertGetAttr(graph->inputs()[0], attr_name) |
48 | ->setType(n->output(0)->type()); |
49 | prepack_op_value->replaceAllUsesWith(packed_weight_attr); |
50 | nodes_to_delete.insert(n); |
51 | } |
52 | } |
53 | for (Block* subblock : n->blocks()) { |
54 | blocks_to_visit.push(subblock); |
55 | } |
56 | } |
57 | } |
58 | for (auto n : nodes_to_delete) { |
59 | n->removeAllInputs(); |
60 | } |
61 | for (auto n : nodes_to_delete) { |
62 | n->destroy(); |
63 | } |
64 | } |
65 | } |
66 | |
67 | } // namespace jit |
68 | } // namespace torch |
69 | |