1 | #pragma once |
---|---|
2 | |
3 | #include <torch/csrc/Export.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <memory> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | struct Graph; |
11 | |
// Exception type used by the property-propagation machinery in this header
// (presumably thrown when a pass cannot infer a property — confirm at throw
// sites). Overrides what() so handlers that catch it as std::exception get a
// meaningful message instead of the implementation-defined default returned
// by std::exception::what().
struct propagation_error : std::exception {
  const char* what() const noexcept override {
    return "torch::jit::propagation_error: property propagation failed";
  }
};
13 | |
14 | class PropertyPropBase { |
15 | // Used for both Shape Propagation and Dtype/Device Propagation |
16 | public: |
17 | explicit PropertyPropBase(std::shared_ptr<Graph> graph) |
18 | : graph_(std::move(graph)) {} |
19 | virtual ~PropertyPropBase() = default; |
20 | |
21 | void propagateBlock(Block* block, bool insert_expands = true); |
22 | // insert_expands is used for shape inference |
23 | |
24 | void processIf(Node* node); |
25 | void processLoop(Node* node); |
26 | |
27 | protected: |
28 | virtual void propagateNode(Node* node, bool insert_expands = true) = 0; |
29 | void setUnshapedType(Value* o); |
30 | void setUnshapedType(Node* node); |
31 | std::shared_ptr<Graph> graph_; |
32 | }; |
33 | |
34 | TORCH_API void EraseShapeInformation(const std::shared_ptr<Graph>& graph); |
35 | TORCH_API void PropagateInputShapes(const std::shared_ptr<Graph>& graph); |
36 | |
37 | TORCH_API bool mergeTypes( |
38 | ArrayRef<Value*> lhs, |
39 | ArrayRef<Value*> rhs, |
40 | ArrayRef<Value*> outputs); |
41 | |
42 | } // namespace jit |
43 | } // namespace torch |
44 |