1 | #pragma once |
2 | |
3 | #include <c10/util/ArrayRef.h> |
4 | #include <torch/csrc/jit/api/function_impl.h> |
5 | #include <torch/csrc/jit/ir/ir.h> |
6 | #include <torch/csrc/lazy/backend/lowering_context.h> |
7 | #include <torch/csrc/lazy/core/ir.h> |
8 | #include <torch/csrc/lazy/core/shape.h> |
9 | #include <torch/csrc/lazy/ts_backend/ts_lowering_context.h> |
10 | |
11 | namespace torch { |
12 | namespace lazy { |
13 | |
14 | using TSOpVector = std::vector<torch::jit::Value*>; |
15 | |
// Base IR node for the TorchScript (TS) lazy backend. Extends the generic
// lazy::Node with backend-specific hashing and a Lower() hook that emits
// torch::jit::Value*s into a TS graph.
class TORCH_API TsNode : public lazy::Node {
 public:
  // Construct a node whose output shapes are already known (one Shape per
  // output; callers are expected to pass #shapes == num_outputs).
  TsNode(
      OpKind op,
      OpList operands,
      std::vector<Shape>&& shapes,
      size_t num_outputs,
      hash_t hash_seed = kHashSeed);

  // Construct a node whose (single) shape is computed lazily by shape_fn,
  // e.g. when shape inference is deferred or cached.
  TsNode(
      OpKind op,
      OpList operands,
      const std::function<Shape()>& shape_fn,
      size_t num_outputs,
      hash_t hash_seed = kHashSeed);

  // Construct a node with no shape information attached yet.
  TsNode(
      OpKind op,
      OpList operands,
      size_t num_outputs,
      hash_t hash_seed = kHashSeed);

  // Construct a leaf node (no operands) with a single known shape.
  TsNode(
      OpKind op,
      Shape shape,
      size_t num_outputs,
      hash_t hash_seed = kHashSeed);

  ~TsNode() override = default;

  // Hash used to look up compiled graphs; see dag_hash_ below for which
  // variant (with/without size info) it corresponds to.
  hash_t hash() const override;

  // Hash that always includes size info; used for shape caching.
  hash_t shapeHash() const override;

  // Returns the Python stack trace captured at node creation time
  // (useful for attributing lazy IR back to user code).
  const std::string getPythonStacktrace() const;

  // Lower is a backend-specific method since it returns a backend specific
  // type. hence, it is convenient to define it differently per-backend rather
  // than at Node API
  virtual TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      TSLoweringContext* loctx) const;

 private:
  // The hash of the dag WITH size info. Used for shape caching
  hash_t shape_hash_;
  // The hash of the dag used to look up the compiled graph by a hash
  // in this case, we will use the dag hash WITHOUT size info if dynamic shape
  // is enabled and use the dag hash WITH size info otherwise.
  hash_t dag_hash_;
};
67 | |
68 | // Note: this OpKind is separate from ltc_ops.h since it would be a circular |
69 | // import otherwise, I like leaving TensorList in this file, and I think most of |
70 | // ltc_ops special cases will be deleted anyway |
71 | const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list" ); |
72 | |
73 | // TensorList represents an at::TensorList which is a vector[Tensor] but is also |
74 | // a first-class IValue and can be fed as a single input to a TS program. It is |
75 | // much easier to handle TensorLists in Lazy Tensor code if they are represented |
76 | // as a single Node so there can be more than one TensorList and more than one |
77 | // Tensor side-by-side as operands to an op. |
78 | // |
79 | // Note: shape is undefined for TensorList. We assert in some places that |
80 | // #shapes matches #outputs and this stems from |
81 | // the fact that currently all IR nodes represent tensors (there is no |
// type system for this IR). Because of this, TensorList is a bit of a
83 | // hack. |
84 | // |
85 | // TODO(whc) once Shape() API is moved to Node base, also make it virtual, and |
86 | // then implement it as NotImplemented for TensorList, also fixing the assertion |
87 | // that would fail. |
struct TORCH_API TensorList : public TsNode {
  // The OpKind shared by every TensorList node; used by the node-reuse
  // machinery to identify candidate nodes of this type.
  static OpKind ClassOpKind() {
    return tensor_list_opkind;
  }

  TensorList() = delete;
  // Build a TensorList node whose operands are `values`, in order.
  TensorList(OpList values);

  // A cached TensorList node can be reused only if it wraps exactly the same
  // outputs, in the same order, as `values`.
  bool CanBeReused(OpList values) const {
    return operands() == std::vector<Output>(values.begin(), values.end());
  }

  // Lowers to the TS graph; see TsNode::Lower for the general contract.
  TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      TSLoweringContext* loctx) const override;
};
104 | |
105 | } // namespace lazy |
106 | } // namespace torch |
107 | |