#include <torch/csrc/jit/passes/clear_profiling.h>

#include <torch/csrc/jit/jit_log.h>

namespace torch {
namespace jit {

8 | void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) { |
9 | for (auto i : graph->inputs()) { |
10 | if (i->type()->isSubtypeOf(*TensorType::get())) { |
11 | i->setType(unshapedType(i->type())); |
12 | } |
13 | } |
14 | } |
15 | |
16 | void unprofileBlock(Block* start_block) { |
17 | std::vector<Block*> stack; |
18 | stack.push_back(start_block); |
19 | |
20 | while (!stack.empty()) { |
21 | Block* block = stack.back(); |
22 | stack.pop_back(); |
23 | |
24 | for (auto n : block->nodes()) { |
25 | for (auto o : n->outputs()) { |
26 | if (o->type()->isSubtypeOf(*TensorType::get())) { |
27 | o->setType(unshapedType(o->type())); |
28 | } |
29 | } |
30 | stack.insert(stack.end(), n->blocks().begin(), n->blocks().end()); |
31 | } |
32 | } |
33 | } |
34 | |
// We need to make sure that passes that use profiling information
// use it **only after** guards validating it are inserted.
// Ideally, we would run any pass that relies on profiling information
// after `InsertBailOuts`; however, practically, some passes
// (e.g. Peephole) are useful to run both w/ and w/o profiling
// information, so we could run them in `preoptimizeGraph` and
// in `runProfilingInsensitiveOptimizations`.
void ClearProfilingInformation(const std::shared_ptr<Graph>& graph) {
  // Reset tensor types on the graph's own inputs, then on every node
  // output in the graph's (possibly nested) blocks.
  unprofileGraphInputs(graph);
  unprofileBlock(graph->block());
  // Dump the resulting graph via jit_log for debugging.
  GRAPH_DUMP("After ClearProfilingInformation: ", graph);
}
} // namespace jit
} // namespace torch