1 | #include <torch/csrc/jit/passes/insert_guards.h> |
---|---|
2 | #include <torch/csrc/jit/runtime/profiling_record.h> |
3 | #include <memory> |
4 | #include <unordered_set> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
9 | struct GuardInserter { |
10 | GuardInserter(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {} |
11 | |
12 | void run() { |
13 | insertGuards(graph_->block()); |
14 | ProfilingRecord::removeProfilingNodes(graph_->block()); |
15 | } |
16 | |
17 | private: |
18 | void insertGuards(Block* b) { |
19 | for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { |
20 | auto n = *it; |
21 | if (n->kind() == prim::profile) { |
22 | auto pttp = n->ty(attr::profiled_type)->cast<TensorType>(); |
23 | if (pttp) { |
24 | auto guard = graph_->create(prim::Guard, {n->input()}, 1); |
25 | auto go = guard->output(); |
26 | go->setType(pttp); |
27 | guard->insertBefore(n); |
28 | n->output()->replaceAllUsesWith(go); |
29 | } else { |
30 | // we didn't go down this path i.e |
31 | // no profiling information is available |
32 | n->output()->replaceAllUsesWith(n->input()); |
33 | } |
34 | it.destroyCurrent(); |
35 | } else { |
36 | for (Block* ib : n->blocks()) { |
37 | insertGuards(ib); |
38 | } |
39 | } |
40 | } |
41 | } |
42 | |
43 | std::shared_ptr<Graph> graph_; |
44 | }; |
45 | |
46 | void InsertGuards(std::shared_ptr<Graph> graph) { |
47 | GuardInserter gi(std::move(graph)); |
48 | gi.run(); |
49 | } |
50 | |
51 | } // namespace jit |
52 | } // namespace torch |
53 |