1#pragma once
2
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>

#include <exception>
#include <memory>
#include <utility>
6
7namespace torch {
8namespace jit {
9
// Forward declaration so the signatures below can name Graph through
// std::shared_ptr without needing the complete type at this point.
// NOTE(review): the full definition presumably comes from ir.h -- confirm.
struct Graph;
11
// Exception type used to signal a propagation failure. It carries no payload;
// handlers are expected to discriminate on the type itself.
// NOTE(review): the throw/catch sites are not in this header -- confirm where
// it is raised and handled.
struct propagation_error : std::exception {
  // Gives a recognisable message if the exception ever escapes to a generic
  // std::exception handler, instead of the implementation-defined default.
  const char* what() const noexcept override {
    return "propagation_error";
  }
};
13
14class PropertyPropBase {
15 // Used for both Shape Propagation and Dtype/Device Propagation
16 public:
17 explicit PropertyPropBase(std::shared_ptr<Graph> graph)
18 : graph_(std::move(graph)) {}
19 virtual ~PropertyPropBase() = default;
20
21 void propagateBlock(Block* block, bool insert_expands = true);
22 // insert_expands is used for shape inference
23
24 void processIf(Node* node);
25 void processLoop(Node* node);
26
27 protected:
28 virtual void propagateNode(Node* node, bool insert_expands = true) = 0;
29 void setUnshapedType(Value* o);
30 void setUnshapedType(Node* node);
31 std::shared_ptr<Graph> graph_;
32};
33
34TORCH_API void EraseShapeInformation(const std::shared_ptr<Graph>& graph);
35TORCH_API void PropagateInputShapes(const std::shared_ptr<Graph>& graph);
36
37TORCH_API bool mergeTypes(
38 ArrayRef<Value*> lhs,
39 ArrayRef<Value*> rhs,
40 ArrayRef<Value*> outputs);
41
42} // namespace jit
43} // namespace torch
44