1#pragma once
2
3#include <c10/util/ArrayRef.h>
4#include <torch/csrc/jit/api/function_impl.h>
5#include <torch/csrc/jit/ir/ir.h>
6#include <torch/csrc/lazy/backend/lowering_context.h>
7#include <torch/csrc/lazy/core/ir.h>
8#include <torch/csrc/lazy/core/shape.h>
9#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
10
11namespace torch {
12namespace lazy {
13
14using TSOpVector = std::vector<torch::jit::Value*>;
15
16class TORCH_API TsNode : public lazy::Node {
17 public:
18 TsNode(
19 OpKind op,
20 OpList operands,
21 std::vector<Shape>&& shapes,
22 size_t num_outputs,
23 hash_t hash_seed = kHashSeed);
24
25 TsNode(
26 OpKind op,
27 OpList operands,
28 const std::function<Shape()>& shape_fn,
29 size_t num_outputs,
30 hash_t hash_seed = kHashSeed);
31
32 TsNode(
33 OpKind op,
34 OpList operands,
35 size_t num_outputs,
36 hash_t hash_seed = kHashSeed);
37
38 TsNode(
39 OpKind op,
40 Shape shape,
41 size_t num_outputs,
42 hash_t hash_seed = kHashSeed);
43
44 ~TsNode() override = default;
45
46 hash_t hash() const override;
47
48 hash_t shapeHash() const override;
49
50 const std::string getPythonStacktrace() const;
51
52 // Lower is a backend-specific method since it returns a backend specific
53 // type. hence, it is convenient to define it differently per-backend rather
54 // than at Node API
55 virtual TSOpVector Lower(
56 std::shared_ptr<torch::jit::GraphFunction> function,
57 TSLoweringContext* loctx) const;
58
59 private:
60 // The hash of the dag WITH size info. Used for shape caching
61 hash_t shape_hash_;
62 // The hash of the dag used to look up the compiled graph by a hash
63 // in this case, we will use the dag hash WITHOUT size info if dynamic shape
64 // is enabled and use the dag hash WITH size info otherwise.
65 hash_t dag_hash_;
66};
67
68// Note: this OpKind is separate from ltc_ops.h since it would be a circular
69// import otherwise, I like leaving TensorList in this file, and I think most of
70// ltc_ops special cases will be deleted anyway
71const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
72
73// TensorList represents an at::TensorList which is a vector[Tensor] but is also
74// a first-class IValue and can be fed as a single input to a TS program. It is
75// much easier to handle TensorLists in Lazy Tensor code if they are represented
76// as a single Node so there can be more than one TensorList and more than one
77// Tensor side-by-side as operands to an op.
78//
// Note: shape is undefined for TensorList. We assert in some places that
// #shapes matches #outputs; this stems from the fact that currently all IR
// nodes represent tensors (there is no type system for this IR). Because of
// this, TensorList is a bit of a hack.
84//
85// TODO(whc) once Shape() API is moved to Node base, also make it virtual, and
86// then implement it as NotImplemented for TensorList, also fixing the assertion
87// that would fail.
88struct TORCH_API TensorList : public TsNode {
89 static OpKind ClassOpKind() {
90 return tensor_list_opkind;
91 }
92
93 TensorList() = delete;
94 TensorList(OpList values);
95
96 bool CanBeReused(OpList values) const {
97 return operands() == std::vector<Output>(values.begin(), values.end());
98 }
99
100 TSOpVector Lower(
101 std::shared_ptr<torch::jit::GraphFunction> function,
102 TSLoweringContext* loctx) const override;
103};
104
105} // namespace lazy
106} // namespace torch
107