1 | #include <torch/csrc/jit/passes/refine_tuple_types.h> |
---|---|
2 | #include <torch/csrc/jit/runtime/graph_iterator.h> |
3 | |
4 | #include <ATen/core/type_factory.h> |
5 | |
6 | #include <utility> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | namespace { |
12 | static void VisitTupleNode(Node* node) { |
13 | TORCH_CHECK( |
14 | node->outputs().size() == 1, "Tuple must have exactly one output!"); |
15 | |
16 | Value* output = node->outputs()[0]; |
17 | auto tuple_type = output->type()->expectRef<TupleType>(); |
18 | |
19 | TORCH_CHECK( |
20 | tuple_type.containedTypes().size() == node->inputs().size(), |
21 | "Number of contained types does not match number of inputs!"); |
22 | |
23 | // Extract updated types from input values. |
24 | std::vector<c10::TypePtr> types; |
25 | for (const Value* input : node->inputs()) { |
26 | types.push_back(input->type()); |
27 | } |
28 | |
29 | // Construct new tuple type based on input types. |
30 | output->setType(tuple_type.withContained(std::move(types))); |
31 | } |
32 | } // anonymous namespace |
33 | |
34 | void RefineTupleTypes(std::shared_ptr<Graph>& graph) { |
35 | DepthFirstGraphNodeIterator it(graph); |
36 | for (auto* node = it.next(); node != nullptr; node = it.next()) { |
37 | if (node->kind() == prim::TupleConstruct) { |
38 | VisitTupleNode(node); |
39 | } |
40 | } |
41 | } |
42 | |
43 | } // namespace jit |
44 | } // namespace torch |
45 |