1#include <torch/csrc/jit/passes/refine_tuple_types.h>
2#include <torch/csrc/jit/runtime/graph_iterator.h>
3
4#include <ATen/core/type_factory.h>
5
6#include <utility>
7
8namespace torch {
9namespace jit {
10
11namespace {
12static void VisitTupleNode(Node* node) {
13 TORCH_CHECK(
14 node->outputs().size() == 1, "Tuple must have exactly one output!");
15
16 Value* output = node->outputs()[0];
17 auto tuple_type = output->type()->expectRef<TupleType>();
18
19 TORCH_CHECK(
20 tuple_type.containedTypes().size() == node->inputs().size(),
21 "Number of contained types does not match number of inputs!");
22
23 // Extract updated types from input values.
24 std::vector<c10::TypePtr> types;
25 for (const Value* input : node->inputs()) {
26 types.push_back(input->type());
27 }
28
29 // Construct new tuple type based on input types.
30 output->setType(tuple_type.withContained(std::move(types)));
31}
32} // anonymous namespace
33
34void RefineTupleTypes(std::shared_ptr<Graph>& graph) {
35 DepthFirstGraphNodeIterator it(graph);
36 for (auto* node = it.next(); node != nullptr; node = it.next()) {
37 if (node->kind() == prim::TupleConstruct) {
38 VisitTupleNode(node);
39 }
40 }
41}
42
43} // namespace jit
44} // namespace torch
45