1#include <torch/csrc/jit/passes/insert_guards.h>
2#include <torch/csrc/jit/runtime/profiling_record.h>
3#include <memory>
4#include <unordered_set>
5
6namespace torch {
7namespace jit {
8
9struct GuardInserter {
10 GuardInserter(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}
11
12 void run() {
13 insertGuards(graph_->block());
14 ProfilingRecord::removeProfilingNodes(graph_->block());
15 }
16
17 private:
18 void insertGuards(Block* b) {
19 for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
20 auto n = *it;
21 if (n->kind() == prim::profile) {
22 auto pttp = n->ty(attr::profiled_type)->cast<TensorType>();
23 if (pttp) {
24 auto guard = graph_->create(prim::Guard, {n->input()}, 1);
25 auto go = guard->output();
26 go->setType(pttp);
27 guard->insertBefore(n);
28 n->output()->replaceAllUsesWith(go);
29 } else {
30 // we didn't go down this path i.e
31 // no profiling information is available
32 n->output()->replaceAllUsesWith(n->input());
33 }
34 it.destroyCurrent();
35 } else {
36 for (Block* ib : n->blocks()) {
37 insertGuards(ib);
38 }
39 }
40 }
41 }
42
43 std::shared_ptr<Graph> graph_;
44};
45
46void InsertGuards(std::shared_ptr<Graph> graph) {
47 GuardInserter gi(std::move(graph));
48 gi.run();
49}
50
51} // namespace jit
52} // namespace torch
53