#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>

namespace torch {
namespace jit {

struct ChunkOutput {
  ChunkOutput(Value* v, size_t o) : val(v), offset(o) {}
  Value* val;
  size_t offset;
};

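// Collects the tensors produced from aten::chunk's output list, together with
// their positions in that list. Every use of the list must be either an
// aten::select with a constant index (producing a Tensor) or a
// prim::ListUnpack with exactly `chunks` outputs; any other use would make the
// rewrite below unsafe, so c10::nullopt is returned instead.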
static c10::optional<std::vector<ChunkOutput>> getChunkOutputs(Node* chunk) {
  std::vector<ChunkOutput> outputs;
  for (auto list_use : chunk->output()->uses()) {
    if (list_use.user->matches(
            "aten::select(t[] list, int idx) -> t", attr::idx) &&
        list_use.user->output()->type()->cast<TensorType>()) {
      outputs.emplace_back(
          list_use.user->output(),
          list_use.user->get<int64_t>(attr::idx).value());
    } else if (list_use.user->kind() == prim::ListUnpack) {
      // This sometimes happens if the sizes can't be evenly divided by the
      // number of chunks.
      if (static_cast<int64_t>(list_use.user->outputs().size()) !=
          chunk->get<int64_t>(attr::chunks).value()) {
        return c10::nullopt;
      }
      auto unpack_outputs = list_use.user->outputs();
      for (const auto i : c10::irange(unpack_outputs.size())) {
        outputs.emplace_back(unpack_outputs[i], i);
      }
    } else {
      return c10::nullopt;
    }
  }
  return outputs;
}

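// Recursively canonicalizes the nodes in `block` (and any nested blocks) into
// forms the graph fuser handles more easily:
//  - aten::add/sub/mul/div whose `other` argument is a constant 0-dim Tensor
//    are rewritten to use the equivalent Scalar constant instead;
//  - aten::chunk with constant `chunks` and `dim` is rewritten to
//    prim::ConstantChunk when all uses of its output list can be rewired.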
static void CanonicalizeOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto sub : it->blocks())
      CanonicalizeOps(sub);
    if (it->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
        it->matches(
            "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
        it->matches("aten::mul(Tensor self, Tensor other) -> Tensor") ||
        it->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
      // Replace rank 0 Tensor constants with scalar constants.
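      // Illustrative (schematic) IR, assuming %c is a constant 0-dim tensor:
      //   %y = aten::add(%x, %c, %alpha)
      // becomes
      //   %s = prim::Constant[value=...]()   // a Scalar constant
      //   %y = aten::add(%x, %s, %alpha)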
      if (auto other = it->get<at::Tensor>(attr::other)) {
        if (other->dim() == 0) {
          WithInsertPoint insert_guard{*it};
          auto graph = it->owningGraph();
          auto new_other = graph->insertConstant(other->item());
          std::vector<Value*> inputs = it->inputs().vec();
          inputs.at(1) = new_other;
          Value* new_output =
              graph->insertNode(graph->create(it->kind(), inputs))->output();
          new_output->node()->copyMetadata(*it);
          new_output->copyMetadata(it->output());
          it->output()->replaceAllUsesWith(new_output);
        }
      }
    } else if (it->matches(
                   "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
                   /*const_inputs=*/{attr::chunks, attr::dim})) {
      // Replace aten::chunk (which returns a Tensor list) with a
      // prim::ConstantChunk node whose outputs are already unpacked.
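      // Illustrative (schematic) IR for chunks=2, dim=0:
      //   %ts : Tensor[] = aten::chunk(%x, %chunks, %dim)
      //   %a : Tensor, %b : Tensor = prim::ListUnpack(%ts)
      // becomes
      //   %a : Tensor, %b : Tensor = prim::ConstantChunk[chunks=2, dim=0](%x)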
      if (auto orig_outputs = getChunkOutputs(*it)) {
        WithInsertPoint guard(*it);
        auto* self = it->namedInput(attr::self);
        auto* graph = it->owningGraph();
        const auto chunks = it->get<int64_t>(attr::chunks).value();
        const auto dim = it->get<int64_t>(attr::dim).value();
        auto* node =
            graph->insertNode(graph->create(prim::ConstantChunk, chunks));
        node->addInput(self);
        node->i_(attr::chunks, chunks)->i_(attr::dim, dim);
        node->copyMetadata(*it);
        for (const auto& orig_out : *orig_outputs) {
          orig_out.val->replaceAllUsesWith(node->outputs()[orig_out.offset]);
          node->outputs()[orig_out.offset]->setType(orig_out.val->type());
        }
      }
    }
  }
}

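// Entry point: canonicalize every block in the graph, then run dead code
// elimination to remove the nodes (e.g. the original aten::chunk and its
// list uses) that the rewrites left without consumers.
//
// Typical use is just to call the pass on a graph, e.g. (a sketch; `module`
// is assumed to be a loaded script module):
//   auto graph = module.get_method("forward").graph();
//   CanonicalizeOps(graph);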
void CanonicalizeOps(const std::shared_ptr<Graph>& graph) {
  CanonicalizeOps(graph->block());
  GRAPH_DUMP("After CanonicalizeOps: ", graph);
  EliminateDeadCode(graph);
}

} // namespace jit
} // namespace torch