1 | #include <torch/csrc/jit/passes/remove_expands.h> |
---|---|
2 | |
3 | namespace torch { |
4 | namespace jit { |
5 | |
6 | static void RemoveExpands(Block* block) { |
7 | for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; |
8 | ++it) { |
9 | for (auto sub : it->blocks()) |
10 | RemoveExpands(sub); |
11 | |
12 | if (it->kind() == aten::expand && it->get<bool>(attr::implicit) == true) { |
13 | it->output()->replaceAllUsesWith(it->namedInput(attr::self)); |
14 | it.destroyCurrent(); |
15 | } |
16 | } |
17 | } |
18 | |
19 | void RemoveExpands(const std::shared_ptr<Graph>& graph) { |
20 | RemoveExpands(graph->block()); |
21 | } |
22 | |
23 | } // namespace jit |
24 | } // namespace torch |
25 |