1#include <torch/csrc/jit/passes/constant_pooling.h>
2
3#include <ATen/core/symbol.h>
4#include <torch/csrc/jit/ir/alias_analysis.h>
5#include <torch/csrc/jit/ir/ir.h>
6#include <torch/csrc/jit/ir/node_hashing.h>
7#include <unordered_set>
8
9namespace torch {
10namespace jit {
11
12namespace {
13
14// Very similar to the common subexpression elimination pass
15// Move all constants to the beginning of the graph, and deduplicate
16void ConstantPooling(
17 Block* block,
18 std::unordered_set<Node*, HashNode, EqualNode>& constants,
19 const AliasDb& aliasDb) {
20 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
21 auto node = *it;
22 // node may be moved to a different block so advance iterator now
23 ++it;
24 if (!node->blocks().empty()) {
25 // Traverse sub-blocks.
26 for (auto block : node->blocks()) {
27 ConstantPooling(block, constants, aliasDb);
28 }
29 continue;
30 }
31
32 if (node->kind() != prim::Constant) {
33 continue;
34 }
35
36 // Check whether the same constant already exists.
37 auto subit = constants.insert(node);
38 if (!subit.second) {
39 auto existing = *subit.first;
40
41 auto old_ivalue = toIValue(existing->output());
42 auto new_ivalue = toIValue(node->output());
43
44 // if both values are the same object, we do not need to worry about
45 // changing the aliasing relationship
46 bool same_identity =
47 (old_ivalue && new_ivalue && (old_ivalue->is(new_ivalue)));
48
49 if (!same_identity &&
50 !aliasDb.safeToChangeAliasingRelationship(
51 node->outputs(), existing->outputs())) {
52 continue;
53 }
54
55 // constant exists, replace the uses of node, and destroy it.
56 node->replaceAllUsesWith(existing);
57 node->destroy();
58 continue;
59 }
60
61 // Move the constant definition to the beginning of the graph.
62 auto first_node = node->owningGraph()->block()->nodes().front();
63 if (node != first_node)
64 node->moveBefore(first_node);
65 }
66}
67} // anonymous namespace
68
69void ConstantPooling(const std::shared_ptr<Graph>& graph) {
70 AliasDb aliasDb(graph);
71 std::unordered_set<Node*, HashNode, EqualNode> constants;
72 ConstantPooling(graph->block(), constants, aliasDb);
73}
74} // namespace jit
75} // namespace torch
76