#pragma once

#include <c10/core/SymNodeImpl.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/util.h>

namespace torch {
namespace lazy {

class TORCH_API SymNodeImpl : public c10::SymNodeImpl {
 public:
  SymNodeImpl(NodePtr ptr) : node_(std::move(ptr)) {}
  NodePtr node_;
};
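
// Illustrative sketch (an assumption, not part of this header): the lazy
// SymNodeImpl above lets an IR node that computes an integer quantity be
// surfaced to the dispatcher as a c10::SymInt. `size_node` below is a
// hypothetical NodePtr producing a dimension size.
//
//   NodePtr size_node = /* IR node computing a size */;
//   auto sym = c10::make_intrusive<torch::lazy::SymNodeImpl>(size_node);
//   c10::SymInt sym_size{c10::SymNode(std::move(sym))};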

class LazyTensor;
using LazyTensorPtr = c10::intrusive_ptr<LazyTensor>;

class TORCH_API LazyTensor : public c10::intrusive_ptr_target {
 public:
  // This is the core lazy tensor data structure where all the tensor data is
  // held. The lazy tensor is nothing more than a shared pointer to a Data
  // object.
  struct Data {
    Data(BackendDataPtr handle, BackendDevice device)
        : handle(std::move(handle)),
          device(std::move(device)),
          unique_id(GetNextTensorId()) {}
    Data(Value ir_value, BackendDevice device)
        : ir_value(std::move(ir_value)),
          device(std::move(device)),
          unique_id(GetNextTensorId()) {}
    Data(at::Tensor tensor_data, BackendDevice device)
        : tensor_data(std::move(tensor_data)),
          device(std::move(device)),
          unique_id(GetNextTensorId()) {}
    // TODO(alanwaketan): Remove this ctor. This is a
    // temporary ctor to ease XLA LTC migration. It depends on
    // XLA's Functionalization integration.
    Data(BackendDevice device)
        : device(std::move(device)), unique_id(GetNextTensorId()) {}

    virtual ~Data();

    // Handle to the device data backing this tensor, once materialized.
    BackendDataPtr handle;
    // IR value (node plus output index) that computes this tensor, if any.
    Value ir_value;
    // Host-side copy of the data, used before it is uploaded to the device.
    c10::optional<at::Tensor> tensor_data;
    const BackendDevice device;
    const int64_t unique_id = 0;
    size_t generation = 1;
  };
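
  // Illustrative sketch (an assumption, not normative): in practice one of
  // {tensor_data, ir_value, handle} usually acts as the authoritative source
  // of the tensor's value, depending on whether the tensor is still
  // host-side, pending computation, or already materialized on the device.
  // For example, given some valid BackendDevice `device`:
  //
  //   at::Tensor cpu = at::ones({2, 3});
  //   auto data = std::make_shared<LazyTensor::Data>(cpu, device);
  //   LazyTensorPtr lt = LazyTensor::Create(std::move(data));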

  static LazyTensorPtr Create(
      const at::Tensor& tensor,
      const BackendDevice& device);
  static LazyTensorPtr Create(Value ir_value, const BackendDevice& device);
  static LazyTensorPtr Create(BackendDataPtr handle);
  static LazyTensorPtr Create(std::shared_ptr<Data> data);

  // The default ctor previously created a null LazyTensor (one with no 'data'
  // obj). Creating a null LazyTensor is no longer possible, since the same can
  // be achieved by creating a null LazyTensorPtr, and it is way too confusing
  // to have to check both lazy_tensor_ptr && *lazy_tensor_ptr. Everywhere that
  // used to rely on a LazyTensor obj with a null Data can now rely on a null
  // LazyTensorPtr instead.
  LazyTensor() = delete;

  ~LazyTensor() override = default;

  size_t generation() const {
    return data()->generation;
  }

  // Override it to use your own Shape.
  virtual int64_t size(int64_t dim) const;

  // Override it to use your own graph executor.
  virtual at::Tensor ToTensor(bool detached);

  void ShallowCopyTo(LazyTensorPtr dest) const;

  // Assigns the tensor value to the lazy tensor.
  void SetTensor(at::Tensor tensor);

  void UpdateFromTensor(at::Tensor tensor, bool sync);
  void UpdateFromTensorOut(at::Tensor tensor);
  void UpdateFromTensorOut(const LazyTensorPtr& tensor);

  const std::shared_ptr<Data>& data() const;

  // Override it to use your own type conversion.
  virtual at::ScalarType dtype() const;

  MaybeRef<Shape> shape() const;

  const BackendDevice& GetDevice() const;
  int64_t GetUniqueId() const;

  // Fetches the data behind the tensor. If the tensor has a graph defining
  // its current value, executes the graph and fetches the data result.
  BackendDataPtr GetDataHandle();

  // Fetches the current value of the data, which can be missing (nullptr)
  // in case the tensor has a graph defining its current value.
  BackendDataPtr CurrentDataHandle() const;

  void SetDataHandle(BackendDataPtr handle);
  void SetDataHandle(BackendDataPtr handle, bool sync);
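
  // Illustrative sketch of the difference between the two accessors above:
  // CurrentDataHandle() only reports an already-materialized result, whereas
  // GetDataHandle() may have to run the pending graph first.
  //
  //   if (BackendDataPtr handle = lazy_tensor->CurrentDataHandle()) {
  //     // Already materialized; no computation is triggered here.
  //   } else {
  //     handle = lazy_tensor->GetDataHandle();  // may execute the graph
  //   }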

  // Retrieves the current IR Node, or nullptr in case no active IR Node is
  // available.
  Value CurrentIrValue() const;

  // Retrieves the IR Node representing this LazyTensor. One will be created if
  // missing. Note that although this is a const API, it actually changes the
  // internal state of the object.
  Value GetIrValue() const;
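
  // Illustrative sketch: CurrentIrValue() is a pure query, while GetIrValue()
  // lazily builds an IR node (e.g. from tensor_data or the device handle) if
  // none exists yet, which is why it mutates internal state despite being
  // const.
  //
  //   Value v = lazy_tensor->CurrentIrValue();
  //   if (!v) {
  //     v = lazy_tensor->GetIrValue();  // creates and caches an IR value
  //   }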

  void SetIrValue(Value ir_value);
  void SetInPlaceIrValue(Value ir_value);

  c10::optional<at::Tensor> CurrentTensorData() const;

  std::vector<LazyTensorPtr> MakeOutputTensors(NodePtr node) const;

  LazyTensorPtr CopyTensorToDevice(const BackendDevice& device);

  // Applies the queue of operations in preparation for using the data.
  // Override it to use your own graph executor.
  virtual void ApplyPendingGraph();

  // Override it to set extra information.
  virtual void AssignIrValue(Value ir_value) const;

 protected:
  explicit LazyTensor(std::shared_ptr<Data> data);

  void SetTensorData(at::Tensor tensor_data);

  // We build a graph accumulating operations, but at a given point we
  // need to force a rendering, otherwise the graph can grow without control.
  // Think:
  //   for i in range(0, 100000):
  //     a = a + b
  void TryLimitGraphSize();

  // Override it to instantiate your own data.
  virtual Value GetIrValueForTensor(
      const at::Tensor& tensor,
      const BackendDevice& device) const;

  Value CreateTensorNode(BackendDataPtr data, bool read_only) const;

 private:
  LazyTensor(const at::Tensor& tensor, const BackendDevice& device);
  LazyTensor(Value ir_value, const BackendDevice& device);
  explicit LazyTensor(BackendDataPtr handle);

  static int64_t GetNextTensorId();

  std::shared_ptr<Data> data_;
};

// Utils to convert at::Tensor to LazyTensor, and vice versa.

// Section 0: at::ITensorListRef => lazy::Value (a TensorList IR node).
// Note: GetTensorList is not fully parallel to GetLtcTensor; a TensorList
// skips the LazyTensor wrappers, on the assumption that the list of underlying
// IR nodes is more useful for downstream computations. TBD.
TORCH_API torch::lazy::Value GetTensorList(at::ITensorListRef tensors);

// Section 1: at::Tensor => LazyTensor.
// Extracts the LazyTensor out of an at::Tensor. Returns a null LazyTensorPtr
// if the tensor is not a lazy tensor.
TORCH_API LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor);

// Extracts the LazyTensor out of an at::Tensor. Throws an exception
// if the tensor is not a lazy tensor.
TORCH_API LazyTensorPtr GetLtcTensor(const at::Tensor& tensor);

// Same as above, applied to a list of tensors.
TORCH_API std::vector<LazyTensorPtr> GetLtcTensors(
    c10::ArrayRef<at::Tensor> tensors);
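
// Illustrative sketch: TryGetLtcTensor is the non-throwing probe, while
// GetLtcTensor asserts that the input really is a lazy tensor. A typical
// guard in a kernel might look like:
//
//   if (LazyTensorPtr lt = TryGetLtcTensor(t)) {
//     // t is backed by a LazyTensor; use lt directly.
//   } else {
//     // Not a lazy tensor; convert or fall back.
//   }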

// If tensor is a lazy tensor type, returns the LazyTensor embedded within it,
// otherwise creates a new lazy tensor wrapping tensor as its data.
TORCH_API LazyTensorPtr GetOrCreateLtcTensor(
    const c10::optional<at::Tensor>& tensor,
    const BackendDevice& device);

TORCH_API LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(
    const at::Tensor& tensor,
    const BackendDevice& device);

// Section 2: LazyTensor => at::Tensor.
// Creates an ATen tensor from a LazyTensor.
TORCH_API at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor);
TORCH_API at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor);
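
// Illustrative round trip (sketch only, assuming a valid BackendDevice
// `device`): wrap a CPU tensor as a lazy tensor, then re-expose it to ATen as
// an at::Tensor whose TensorImpl is backed by the LazyTensor.
//
//   LazyTensorPtr lt = LazyTensor::Create(cpu_tensor, device);
//   at::Tensor aten_view = CreateAtenFromLtcTensor(lt);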

// Note [Lazy Tensor Functionalization]
// The functionalization pass is implemented by wrapping all TensorImpl
// objects in C++ with an extra FunctionalTensorWrapper object that knows
// how to perform functionalization.
//
// Certain functions in the aten API serve as entry/exit points for
// functionalization, where we need to perform the wrapping/unwrapping:
// - aten::to.device
// - aten::empty

// Given a non-lazy tensor, this function creates a lazy tensor on the
// specified (lazy) device. functionalize_output determines whether or not we
// should wrap the output in a "functional wrapper".
//
// How do you know whether to pass true/false for functionalize_output?
//
// Case 1: nonlazy -> lazy
// If you're implementing a function that takes in nonlazy tensors and returns
// lazy tensors, then you should think of that function as an "entry point" to
// functionalization, and use functionalize_output=true. Examples include:
// - factory functions (the LTC kernel for at::empty)
// - CPU -> lazy device conversions (the LTC kernel for at::to_device)
//
// Case 2: lazy -> lazy
// If you're implementing a function that takes in lazy tensors and returns
// lazy tensors, **but** requires creating lazy tensors internally, then you
// can assume that the current function is running inside of some outer
// context where functionalization is already running, which will take care
// of doing the wrapping for you; use functionalize_output=false. Examples
// include:
// - CPU fallback (takes in lazy tensors, converts them to CPU, calls the
//   kernel, converts the returns back to lazy tensors).
TORCH_API at::Tensor to_lazy_tensor(
    const at::Tensor& self,
    const c10::TensorOptions& options,
    at::Device device,
    bool non_blocking,
    bool functionalize_output);
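
// Illustrative sketch of the two cases above (tensor/device names are
// hypothetical):
//
//   // Case 1: entry point from a non-lazy tensor, so wrap the output.
//   at::Tensor lazy_in = to_lazy_tensor(
//       cpu_tensor, options, lazy_device, /*non_blocking=*/false,
//       /*functionalize_output=*/true);
//
//   // Case 2: already inside a functionalized context (e.g. CPU fallback),
//   // where the outer layer performs the wrapping.
//   at::Tensor relazied = to_lazy_tensor(
//       cpu_result, options, lazy_device, /*non_blocking=*/false,
//       /*functionalize_output=*/false);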

// Implementation helper for TupleAtenFromLtcTensors below.
template <size_t... Indices>
auto TupleAtenFromLtcTensorsImpl(
    const std::vector<LazyTensorPtr>& tensors,
    std::index_sequence<Indices...>) {
  return std::make_tuple(CreateAtenFromLtcTensor(tensors[Indices])...);
}

// Converts the first N entries of a vector of LazyTensors into a std::tuple
// of at::Tensors.
template <size_t N>
auto TupleAtenFromLtcTensors(const std::vector<LazyTensorPtr>& tensors) {
  return TupleAtenFromLtcTensorsImpl(tensors, std::make_index_sequence<N>{});
}
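
// Illustrative sketch: a kernel returning two outputs can convert a vector of
// LazyTensors (e.g. from MakeOutputTensors) into a std::tuple of at::Tensors
// in a single call.
//
//   std::vector<LazyTensorPtr> outs = /* two lazy outputs */;
//   auto [out0, out1] = TupleAtenFromLtcTensors<2>(outs);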

} // namespace lazy
} // namespace torch