1#include <torch/csrc/jit/passes/dead_code_elimination.h>
2
3#include <torch/csrc/jit/ir/alias_analysis.h>
4#include <torch/csrc/jit/ir/ir_views.h>
5#include <torch/csrc/jit/jit_log.h>
6#include <torch/csrc/utils/memory.h>
7
8namespace torch {
9namespace jit {
10
11void RemoveRedundantProfiles(Block* block, AliasDb& db) {
12 for (auto it = block->nodes().end()->reverseIterator();
13 it != block->nodes().begin();) {
14 Node* n = *it;
15 it++;
16
17 for (Block* b : n->blocks()) {
18 RemoveRedundantProfiles(b, db);
19 }
20
21 // we only check prim::profile and not prim::profile_ivalue bc profile
22 // is inserted on each use, while profile_ivalue is inserted on the def
23 if (n->kind() != prim::profile ||
24 n->input()->node()->kind() != prim::profile) {
25 continue;
26 }
27
28 Node* input_node = n->input()->node();
29 if (input_node->ty(attr::profiled_type) != n->ty(attr::profiled_type)) {
30 continue;
31 }
32
33 if (!db.moveBeforeTopologicallyValid(input_node, n)) {
34 continue;
35 }
36
37 n->output()->replaceAllUsesWith(n->input());
38 n->destroy();
39 }
40}
41
42void RemoveRedundantProfiles(std::shared_ptr<Graph>& graph) {
43 AliasDb db(graph);
44 RemoveRedundantProfiles(graph->block(), db);
45}
46
47} // namespace jit
48} // namespace torch
49