1#include <torch/csrc/jit/passes/clear_undefinedness.h>
2
3#include <torch/csrc/jit/jit_log.h>
4
5namespace torch {
6namespace jit {
7
8void clearUndefinedness(Value* o) {
9 if (o->type()->kind() == TensorType::Kind) {
10 o->setType(TensorType::get());
11 } else if (
12 o->type()->kind() == ListType::Kind &&
13 o->type()->expectRef<ListType>().getElementType()->kind() ==
14 TensorType::Kind) {
15 o->setType(ListType::create(TensorType::get()));
16 }
17}
18
19void clearUndefinedness(Block* block) {
20 for (auto n : block->nodes()) {
21 for (auto o : n->outputs()) {
22 clearUndefinedness(o);
23 }
24 for (auto ib : n->blocks()) {
25 clearUndefinedness(ib);
26 }
27 }
28}
29
30void ClearUndefinedness(const std::shared_ptr<Graph>& graph) {
31 for (auto i : graph->inputs()) {
32 clearUndefinedness(i);
33 }
34 clearUndefinedness(graph->block());
35 GRAPH_DUMP("After removeUndefinedness: ", graph);
36}
37
38} // namespace jit
39} // namespace torch
40