1 | #pragma once |
2 | |
3 | #include <c10/util/variant.h> |
4 | #include <torch/csrc/Export.h> |
5 | #include <torch/csrc/jit/ir/ir.h> |
6 | #include <unordered_map> |
7 | #include <utility> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
// CAUTION: NOT TO BE USED. STILL A WIP, NOT STABLE.
//
// Declaration only (marked TORCH_API); the definition is not in this header.
// `graph` is taken by non-const reference to the shared_ptr and nothing is
// returned.
// NOTE(review): judging by the name, this propagates shape information over
// `graph`, presumably updating it in place -- confirm against the definition.
TORCH_API void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph);
15 | |
// CAUTION: NOT TO BE USED. STILL A WIP, NOT STABLE.
// For the range [beg, end), attempt to propagate shapes and build up a graph
// that computes all remaining symbolic shapes in [beg, end); that graph can
// be executed before beg.
20 | |
21 | struct ShapeComputeGraphMapping { |
22 | ShapeComputeGraphMapping( |
23 | std::shared_ptr<Graph> partial_eval_shape_graph, |
24 | std::unordered_map<Value*, Value*> |
25 | enclosing_graph_value_to_shape_graph_input, |
26 | std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim) |
27 | : partial_eval_shape_graph(std::move(partial_eval_shape_graph)), |
28 | enclosing_graph_value_to_shape_graph_input_( |
29 | std::move(enclosing_graph_value_to_shape_graph_input)), |
30 | graph_output_to_symbolic_shape_dim_( |
31 | std::move(graph_output_to_symbolic_shape_dim)){}; |
32 | |
33 | std::shared_ptr<Graph> partial_eval_shape_graph; |
34 | std::unordered_map<Value*, Value*> |
35 | enclosing_graph_value_to_shape_graph_input_; |
36 | std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim_; |
37 | }; |
38 | |
// Declaration only (marked TORCH_API); the definition is not in this header.
// Per this header's own notes: over the range [beg, end), attempts to
// propagate shapes and to build up a graph that computes the remaining
// symbolic shapes in that range.
//
// `graph` is taken by non-const reference to the shared_ptr; `beg` and `end`
// are raw Node pointers delimiting the half-open range.
// NOTE(review): presumably `beg` and `end` are nodes of `graph`, and an empty
// optional means no mapping could be built -- confirm against the definition.
TORCH_API c10::optional<ShapeComputeGraphMapping>
PropagateShapesAndBuildLargeShapeComputeGraph(
    std::shared_ptr<Graph>& graph,
    Node* beg,
    Node* end);
44 | |
// Symbolic shape analysis test mode: do not insert complete tensor shapes in
// shape compute graphs, and instead rely on the partial evaluation pipeline
// to propagate information. This is a good proxy for our ability to
// propagate non-complete shape information.
//
// Both functions are declarations only; their definitions are not in this
// header.
// NOTE(review): the setter's bool return is presumably the previous setting
// -- confirm against the definition.
TORCH_API bool setSymbolicShapeAnalysisTestMode(bool value);
// NOTE(review): by its name, reports whether the test mode above is currently
// enabled -- confirm against the definition.
TORCH_API bool symbolicShapeAnalysisTestModeEnabled();
51 | |
// One input to calculateSymbolicShapesOnOp: holds either an IValue or a
// c10::SymbolicShape.
// NOTE(review): "SSA" here presumably abbreviates "symbolic shape analysis"
// -- confirm.
using SSAInput = c10::variant<IValue, c10::SymbolicShape>;
// Declaration only (marked TORCH_API); the definition is not in this header.
// Takes a pointer to a FunctionSchema and a vector of SSAInput, and returns
// an optional vector of c10::SymbolicShape.
// NOTE(review): presumably `inputs` lines up with the schema's arguments, the
// result holds one shape per op output, and an empty optional means the
// shapes could not be computed -- confirm against the definition.
TORCH_API c10::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& inputs);
57 | } // namespace jit |
58 | } // namespace torch |
59 | |