1#pragma once
2
3#include <c10/util/variant.h>
4#include <torch/csrc/Export.h>
5#include <torch/csrc/jit/ir/ir.h>
6#include <unordered_map>
7#include <utility>
8
9namespace torch {
10namespace jit {
11
// CAUTION: NOT TO BE USED. STILL A WIP, NOT STABLE.
//
// Runs shape propagation over `graph`.
// NOTE(review): presumably refines the types/shapes of the graph's values in
// place; the definition is not visible from this header -- confirm. `graph` is
// taken by non-const reference to the shared_ptr, so the callee is also able
// to reseat the pointer itself.
TORCH_API void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph);
15
// CAUTION: NOT TO BE USED. STILL A WIP, NOT STABLE.
// For the nodes in the half-open range [beg, end), attempt to propagate shapes
// and build up a graph that computes all remaining symbolic shapes in
// [beg, end) and that can be executed before `beg`.
20
21struct ShapeComputeGraphMapping {
22 ShapeComputeGraphMapping(
23 std::shared_ptr<Graph> partial_eval_shape_graph,
24 std::unordered_map<Value*, Value*>
25 enclosing_graph_value_to_shape_graph_input,
26 std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim)
27 : partial_eval_shape_graph(std::move(partial_eval_shape_graph)),
28 enclosing_graph_value_to_shape_graph_input_(
29 std::move(enclosing_graph_value_to_shape_graph_input)),
30 graph_output_to_symbolic_shape_dim_(
31 std::move(graph_output_to_symbolic_shape_dim)){};
32
33 std::shared_ptr<Graph> partial_eval_shape_graph;
34 std::unordered_map<Value*, Value*>
35 enclosing_graph_value_to_shape_graph_input_;
36 std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim_;
37};
38
// Propagates shapes over the node range [beg, end) of `graph` and builds a
// single shape compute graph for the symbolic shapes that remain.
//
// @param graph  the enclosing graph; passed as a non-const reference to the
//               shared_ptr.
// @param beg    first node of the range (inclusive).
// @param end    end of the range (exclusive).
// @return the shape graph and its mappings back to `graph`, or c10::nullopt.
// NOTE(review): the conditions under which nullopt is returned are not visible
// from this header (presumably when there are no symbolic shapes left to
// compute) -- confirm against the definition.
TORCH_API c10::optional<ShapeComputeGraphMapping>
PropagateShapesAndBuildLargeShapeComputeGraph(
    std::shared_ptr<Graph>& graph,
    Node* beg,
    Node* end);
44
// Symbolic shape analysis test mode. When enabled, complete tensor shapes are
// not inserted into shape compute graphs. Instead, the analysis relies on the
// partial evaluation pipeline to propagate information. This makes test mode a
// good proxy for how well non-complete shape information is propagated.

// Turns test mode on (`value == true`) or off.
// NOTE(review): the meaning of the returned bool is not visible here
// (presumably the previous setting) -- confirm against the definition.
TORCH_API bool setSymbolicShapeAnalysisTestMode(bool value);
// NOTE(review): by its name, reports whether test mode is currently enabled
// -- confirm against the definition.
TORCH_API bool symbolicShapeAnalysisTestModeEnabled();
51
// One operator input for calculateSymbolicShapesOnOp: either an IValue or a
// c10::SymbolicShape.
// NOTE(review): presumably SymbolicShape stands in for tensor arguments and
// IValue carries non-tensor/constant arguments -- confirm.
using SSAInput = c10::variant<IValue, c10::SymbolicShape>;
// Computes symbolic shapes for the operator described by `schema`, given
// `inputs`. Returns a vector of c10::SymbolicShape, or c10::nullopt.
// NOTE(review): presumably `inputs` follow the schema's argument order, the
// result holds one shape per operator output, and nullopt means no shape could
// be computed (e.g. no shape function for `schema`) -- confirm.
TORCH_API c10::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& inputs);
57} // namespace jit
58} // namespace torch
59