1#include <stack>
2
3#include <torch/csrc/jit/api/module.h>
4#include <torch/csrc/jit/passes/constant_pooling.h>
5#include <torch/csrc/jit/passes/constant_propagation.h>
6#include <torch/csrc/jit/passes/prepack_folding.h>
7
8namespace torch {
9namespace jit {
10
11// Must run this pass after constant folding.
12void PrePackingOpsFolder(
13 script::Module& m,
14 const PrePackingOpsFilterFn& is_foldable_op,
15 const std::string& attr_prefix) {
16 for (auto& method : m.get_methods()) {
17 int64_t uid = 0; // int + method name gives unique identifier
18 auto graph = method.graph();
19 std::stack<Block*> blocks_to_visit;
20 std::unordered_set<Node*> nodes_to_delete;
21 blocks_to_visit.push(graph->block());
22 std::string attr_name_base =
23 attr_prefix + "_" + method.name() + "._jit_pass_packed_weight_";
24 while (!blocks_to_visit.empty()) {
25 Block* b = blocks_to_visit.top();
26 blocks_to_visit.pop();
27 for (Node* n : b->nodes()) {
28 if (is_foldable_op(n)) {
29 auto optional_outputs = runNodeIfInputsAreConstant(n);
30 if (optional_outputs) {
31 auto outputs = optional_outputs.value();
32 TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output");
33 auto attr_name = attr_name_base + c10::to_string(uid++);
34 TORCH_CHECK(
35 !(m.type()->findAttributeSlot(attr_name)),
36 "Attribute name ",
37 attr_name,
38 " already exist in",
39 " module of type:",
40 m.type()->name()->qualifiedName(),
41 ". Please make sure that",
42 " FoldPrePackingOps is run at the top level module only.");
43 m.register_attribute(attr_name, n->output(0)->type(), outputs[0]);
44 Value* prepack_op_value = n->output(0);
45 WithInsertPoint ins(prepack_op_value->node());
46 Value* packed_weight_attr =
47 graph->insertGetAttr(graph->inputs()[0], attr_name)
48 ->setType(n->output(0)->type());
49 prepack_op_value->replaceAllUsesWith(packed_weight_attr);
50 nodes_to_delete.insert(n);
51 }
52 }
53 for (Block* subblock : n->blocks()) {
54 blocks_to_visit.push(subblock);
55 }
56 }
57 }
58 for (auto n : nodes_to_delete) {
59 n->removeAllInputs();
60 }
61 for (auto n : nodes_to_delete) {
62 n->destroy();
63 }
64 }
65}
66
67} // namespace jit
68} // namespace torch
69