1 | #include <c10/util/irange.h> |
2 | #include <torch/csrc/jit/jit_log.h> |
3 | #include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h> |
4 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
9 | struct ChunkOutput { |
10 | ChunkOutput(Value* v, size_t o) : val(v), offset(o){}; |
11 | Value* val; |
12 | size_t offset; |
13 | }; |
14 | |
15 | static c10::optional<std::vector<ChunkOutput>> getChunkOutputs(Node* chunk) { |
16 | std::vector<ChunkOutput> outputs; |
17 | for (auto list_use : chunk->output()->uses()) { |
18 | if (list_use.user->matches( |
19 | "aten::select(t[] list, int idx) -> t" , attr::idx) && |
20 | list_use.user->output()->type()->cast<TensorType>()) { |
21 | outputs.emplace_back( |
22 | list_use.user->output(), |
23 | list_use.user->get<int64_t>(attr::idx).value()); |
24 | } else if (list_use.user->kind() == prim::ListUnpack) { |
25 | // This sometimes happens if the sizes can't be evenly divided by the |
26 | // number of chunks |
27 | if (static_cast<int64_t>(list_use.user->outputs().size()) != |
28 | chunk->get<int64_t>(attr::chunks).value()) { |
29 | return c10::nullopt; |
30 | } |
31 | auto unpack_outputs = list_use.user->outputs(); |
32 | for (const auto i : c10::irange(unpack_outputs.size())) { |
33 | outputs.emplace_back(unpack_outputs[i], i); |
34 | } |
35 | } else { |
36 | return c10::nullopt; |
37 | } |
38 | } |
39 | return outputs; |
40 | } |
41 | |
42 | static void CanonicalizeOps(Block* block) { |
43 | for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; |
44 | ++it) { |
45 | for (auto sub : it->blocks()) |
46 | CanonicalizeOps(sub); |
47 | if (it->matches( |
48 | "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" ) || |
49 | it->matches( |
50 | "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" ) || |
51 | it->matches("aten::mul(Tensor self, Tensor other) -> Tensor" ) || |
52 | it->matches("aten::div(Tensor self, Tensor other) -> Tensor" )) { |
53 | // Replace rank 0 Tensor constants with scalar constants. |
54 | if (auto other = it->get<at::Tensor>(attr::other)) { |
55 | if (other->dim() == 0) { |
56 | WithInsertPoint insert_guard{*it}; |
57 | auto graph = it->owningGraph(); |
58 | auto new_other = graph->insertConstant(other->item()); |
59 | std::vector<Value*> inputs = it->inputs().vec(); |
60 | inputs.at(1) = new_other; |
61 | Value* new_output = |
62 | graph->insertNode(graph->create(it->kind(), inputs))->output(); |
63 | new_output->node()->copyMetadata(*it); |
64 | new_output->copyMetadata(it->output()); |
65 | it->output()->replaceAllUsesWith(new_output); |
66 | } |
67 | } |
68 | } else if (it->matches( |
69 | "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]" , |
70 | /*const_inputs=*/{attr::chunks, attr::dim})) { |
71 | // Replace aten::chunk (which returns a list) with ConstantChunk with the |
72 | // outputs unpacked. |
73 | if (auto orig_outputs = getChunkOutputs(*it)) { |
74 | WithInsertPoint guard(*it); |
75 | auto* self = it->namedInput(attr::self); |
76 | auto* graph = it->owningGraph(); |
77 | const auto chunks = it->get<int64_t>(attr::chunks).value(); |
78 | const auto dim = it->get<int64_t>(attr::dim).value(); |
79 | auto* node = |
80 | graph->insertNode(graph->create(prim::ConstantChunk, chunks)); |
81 | node->addInput(self); |
82 | node->i_(attr::chunks, chunks)->i_(attr::dim, dim); |
83 | node->copyMetadata(*it); |
84 | for (const auto& orig_out : *orig_outputs) { |
85 | orig_out.val->replaceAllUsesWith(node->outputs()[orig_out.offset]); |
86 | node->outputs()[orig_out.offset]->setType(orig_out.val->type()); |
87 | } |
88 | } |
89 | } |
90 | } |
91 | } |
92 | |
93 | void CanonicalizeOps(const std::shared_ptr<Graph>& graph) { |
94 | CanonicalizeOps(graph->block()); |
95 | GRAPH_DUMP("After CanonicalizeOps: " , graph); |
96 | EliminateDeadCode(graph); |
97 | } |
98 | |
99 | } // namespace jit |
100 | } // namespace torch |
101 | |