1 | #include <torch/csrc/jit/passes/constant_pooling.h> |
2 | |
3 | #include <ATen/core/symbol.h> |
4 | #include <torch/csrc/jit/ir/alias_analysis.h> |
5 | #include <torch/csrc/jit/ir/ir.h> |
6 | #include <torch/csrc/jit/ir/node_hashing.h> |
7 | #include <unordered_set> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | namespace { |
13 | |
14 | // Very similar to the common subexpression elimination pass |
15 | // Move all constants to the beginning of the graph, and deduplicate |
16 | void ConstantPooling( |
17 | Block* block, |
18 | std::unordered_set<Node*, HashNode, EqualNode>& constants, |
19 | const AliasDb& aliasDb) { |
20 | for (auto it = block->nodes().begin(); it != block->nodes().end();) { |
21 | auto node = *it; |
22 | // node may be moved to a different block so advance iterator now |
23 | ++it; |
24 | if (!node->blocks().empty()) { |
25 | // Traverse sub-blocks. |
26 | for (auto block : node->blocks()) { |
27 | ConstantPooling(block, constants, aliasDb); |
28 | } |
29 | continue; |
30 | } |
31 | |
32 | if (node->kind() != prim::Constant) { |
33 | continue; |
34 | } |
35 | |
36 | // Check whether the same constant already exists. |
37 | auto subit = constants.insert(node); |
38 | if (!subit.second) { |
39 | auto existing = *subit.first; |
40 | |
41 | auto old_ivalue = toIValue(existing->output()); |
42 | auto new_ivalue = toIValue(node->output()); |
43 | |
44 | // if both values are the same object, we do not need to worry about |
45 | // changing the aliasing relationship |
46 | bool same_identity = |
47 | (old_ivalue && new_ivalue && (old_ivalue->is(new_ivalue))); |
48 | |
49 | if (!same_identity && |
50 | !aliasDb.safeToChangeAliasingRelationship( |
51 | node->outputs(), existing->outputs())) { |
52 | continue; |
53 | } |
54 | |
55 | // constant exists, replace the uses of node, and destroy it. |
56 | node->replaceAllUsesWith(existing); |
57 | node->destroy(); |
58 | continue; |
59 | } |
60 | |
61 | // Move the constant definition to the beginning of the graph. |
62 | auto first_node = node->owningGraph()->block()->nodes().front(); |
63 | if (node != first_node) |
64 | node->moveBefore(first_node); |
65 | } |
66 | } |
67 | } // anonymous namespace |
68 | |
69 | void ConstantPooling(const std::shared_ptr<Graph>& graph) { |
70 | AliasDb aliasDb(graph); |
71 | std::unordered_set<Node*, HashNode, EqualNode> constants; |
72 | ConstantPooling(graph->block(), constants, aliasDb); |
73 | } |
74 | } // namespace jit |
75 | } // namespace torch |
76 | |