1 | #include <torch/csrc/jit/passes/clear_undefinedness.h> |
---|---|
2 | |
3 | #include <torch/csrc/jit/jit_log.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | void clearUndefinedness(Value* o) { |
9 | if (o->type()->kind() == TensorType::Kind) { |
10 | o->setType(TensorType::get()); |
11 | } else if ( |
12 | o->type()->kind() == ListType::Kind && |
13 | o->type()->expectRef<ListType>().getElementType()->kind() == |
14 | TensorType::Kind) { |
15 | o->setType(ListType::create(TensorType::get())); |
16 | } |
17 | } |
18 | |
19 | void clearUndefinedness(Block* block) { |
20 | for (auto n : block->nodes()) { |
21 | for (auto o : n->outputs()) { |
22 | clearUndefinedness(o); |
23 | } |
24 | for (auto ib : n->blocks()) { |
25 | clearUndefinedness(ib); |
26 | } |
27 | } |
28 | } |
29 | |
30 | void ClearUndefinedness(const std::shared_ptr<Graph>& graph) { |
31 | for (auto i : graph->inputs()) { |
32 | clearUndefinedness(i); |
33 | } |
34 | clearUndefinedness(graph->block()); |
35 | GRAPH_DUMP("After removeUndefinedness: ", graph); |
36 | } |
37 | |
38 | } // namespace jit |
39 | } // namespace torch |
40 |