1 | #pragma once |
2 | |
3 | #include <c10/core/SymNodeImpl.h> |
4 | #include <c10/util/intrusive_ptr.h> |
5 | #include <torch/csrc/lazy/backend/backend_data.h> |
6 | #include <torch/csrc/lazy/backend/backend_device.h> |
7 | #include <torch/csrc/lazy/core/ir.h> |
8 | #include <torch/csrc/lazy/core/util.h> |
9 | |
10 | namespace torch { |
11 | namespace lazy { |
12 | |
13 | class TORCH_API SymNodeImpl : public c10::SymNodeImpl { |
14 | public: |
  SymNodeImpl(NodePtr ptr) : node_(std::move(ptr)) {}
16 | NodePtr node_; |
17 | }; |
18 | |
19 | class LazyTensor; |
20 | using LazyTensorPtr = c10::intrusive_ptr<LazyTensor>; |
21 | |
22 | class TORCH_API LazyTensor : public c10::intrusive_ptr_target { |
23 | public: |
24 | // This is the core lazy tensor data structure where all the tensor data is |
25 | // held. The lazy tensor is nothing more than a shared pointer to a Data |
26 | // object. |
27 | struct Data { |
28 | Data(BackendDataPtr handle, BackendDevice device) |
29 | : handle(std::move(handle)), |
30 | device(std::move(device)), |
31 | unique_id(GetNextTensorId()) {} |
32 | Data(Value ir_value, BackendDevice device) |
33 | : ir_value(std::move(ir_value)), |
34 | device(std::move(device)), |
35 | unique_id(GetNextTensorId()) {} |
36 | Data(at::Tensor tensor_data, BackendDevice device) |
37 | : tensor_data(std::move(tensor_data)), |
38 | device(std::move(device)), |
39 | unique_id(GetNextTensorId()) {} |
40 | // TODO(alanwaketan): Remove this ctor. This is a |
41 | // temporary ctor to ease XLA LTC migration. It depends on |
42 | // XLA's Functionalization integration. |
43 | Data(BackendDevice device) |
44 | : device(std::move(device)), unique_id(GetNextTensorId()) {} |
45 | |
46 | virtual ~Data(); |
47 | |
    // Materialized device data, if any.
    BackendDataPtr handle;
    // Pending IR computation defining the tensor's value, if any.
    Value ir_value;
    // Eager tensor data, if any.
    c10::optional<at::Tensor> tensor_data;
    const BackendDevice device;
    const int64_t unique_id = 0;
    size_t generation = 1;
54 | }; |
55 | |
56 | static LazyTensorPtr Create( |
57 | const at::Tensor& tensor, |
58 | const BackendDevice& device); |
59 | static LazyTensorPtr Create(Value ir_value, const BackendDevice& device); |
60 | static LazyTensorPtr Create(BackendDataPtr handle); |
61 | static LazyTensorPtr Create(std::shared_ptr<Data> data); |
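
  // A minimal construction sketch (hypothetical; assumes a lazy backend is
  // registered so that a default-constructed BackendDevice is valid):
  //
  //   at::Tensor cpu = at::ones({2, 2});
  //   LazyTensorPtr lt = LazyTensor::Create(cpu, BackendDevice());
  //   // `lt` starts out backed by tensor_data; device data and IR values
  //   // are produced lazily as the tensor is used.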
62 | |
63 | // The default ctor previously created a null LazyTensor (one with no 'data' |
64 | // obj). Creating a null LazyTensor is no longer possible, since the same can |
65 | // be achieved by creating a null LazyTensorPtr and it is way too confusing to |
66 | // have to check both lazy_tensor_ptr && *lazy_tensor_ptr, so everywhere that |
67 | // used to rely on a LazyTensor obj with a null Data can now rely on a null |
68 | // LazyTensorPtr instead. |
69 | LazyTensor() = delete; |
70 | |
71 | ~LazyTensor() override = default; |
72 | |
73 | size_t generation() const { |
74 | return data()->generation; |
75 | } |
76 | |
77 | // Override it to use your own Shape. |
78 | virtual int64_t size(int64_t dim) const; |
79 | |
80 | // Override it to use your own graph executor. |
81 | virtual at::Tensor ToTensor(bool detached); |
82 | |
83 | void ShallowCopyTo(LazyTensorPtr dest) const; |
84 | |
85 | // Assigns the tensor value to the lazy tensor. |
86 | void SetTensor(at::Tensor tensor); |
87 | |
88 | void UpdateFromTensor(at::Tensor tensor, bool sync); |
89 | void UpdateFromTensorOut(at::Tensor tensor); |
90 | void UpdateFromTensorOut(const LazyTensorPtr& tensor); |
91 | |
92 | const std::shared_ptr<Data>& data() const; |
93 | |
94 | // Override it to use your own type conversion. |
95 | virtual at::ScalarType dtype() const; |
96 | |
97 | MaybeRef<Shape> shape() const; |
98 | |
99 | const BackendDevice& GetDevice() const; |
100 | int64_t GetUniqueId() const; |
101 | |
102 | // Fetches the data behind the tensor. If the tensor has a graph defining |
103 | // its current value, executes the graph and fetches the data result. |
104 | BackendDataPtr GetDataHandle(); |
105 | |
  // Fetches the current value of the data, which can be missing (nullptr)
  // in case the tensor has a graph defining its current value.
  BackendDataPtr CurrentDataHandle() const;
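  // For example (hypothetical; `lt` is a LazyTensorPtr):
  //
  //   BackendDataPtr cur = lt->CurrentDataHandle(); // nullptr while the
  //                                                 // value is a pending graph
  //   BackendDataPtr mat = lt->GetDataHandle();     // executes the graph if
  //                                                 // needed to materialize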
109 | |
110 | void SetDataHandle(BackendDataPtr handle); |
111 | void SetDataHandle(BackendDataPtr handle, bool sync); |
112 | |
  // Retrieves the current IR Node, or a null Value in case no active IR Node
  // is available.
115 | Value CurrentIrValue() const; |
116 | |
  // Retrieves the IR Node representing this LazyTensor. One will be created
  // if missing. Note that although this is a const API, it actually changes
  // the internal state of the object.
120 | Value GetIrValue() const; |
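  // For example (hypothetical; `lt` is a LazyTensorPtr):
  //
  //   Value cur = lt->CurrentIrValue(); // may be a null Value
  //   Value ir = lt->GetIrValue();      // creates the IR Node on demand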
121 | |
122 | void SetIrValue(Value ir_value); |
123 | void SetInPlaceIrValue(Value ir_value); |
124 | |
125 | c10::optional<at::Tensor> CurrentTensorData() const; |
126 | |
127 | std::vector<LazyTensorPtr> MakeOutputTensors(NodePtr node) const; |
128 | |
129 | LazyTensorPtr CopyTensorToDevice(const BackendDevice& device); |
130 | |
131 | // Applies the queue of operations in preparation for using the data. |
132 | // Override it to use your own graph executor. |
133 | virtual void ApplyPendingGraph(); |
134 | |
135 | // Override it to set extra information. |
136 | virtual void AssignIrValue(Value ir_value) const; |
137 | |
138 | protected: |
139 | explicit LazyTensor(std::shared_ptr<Data> data); |
140 | |
141 | void SetTensorData(at::Tensor tensor_data); |
142 | |
143 | // We build a graph accumulating operations, but at a given point we |
144 | // need to force a rendering, otherwise the graph can grow without control. |
145 | // Think: |
146 | // for i in range(0, 100000): |
147 | // a = a + b |
148 | void TryLimitGraphSize(); |
149 | |
150 | // Override it to instantiate your own data. |
151 | virtual Value GetIrValueForTensor( |
152 | const at::Tensor& tensor, |
153 | const BackendDevice& device) const; |
154 | |
155 | Value CreateTensorNode(BackendDataPtr data, bool read_only) const; |
156 | |
157 | private: |
158 | LazyTensor(const at::Tensor& tensor, const BackendDevice& device); |
159 | LazyTensor(Value ir_value, const BackendDevice& device); |
160 | explicit LazyTensor(BackendDataPtr handle); |
161 | |
162 | static int64_t GetNextTensorId(); |
163 | |
164 | std::shared_ptr<Data> data_; |
165 | }; |
166 | |
167 | // Utils to convert at::Tensor to LazyTensor, and vice versa. |
168 | |
// Section 0: c10::TensorList => lazy::TensorList
// Note: GetTensorList is not totally parallel to GetLtcTensor; a TensorList
// skips the LazyTensor wrappers, assuming that the list of underlying IR
// nodes is actually more useful for downstream computations. TBD.
174 | TORCH_API torch::lazy::Value GetTensorList(at::ITensorListRef tensors); |
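// A usage sketch (hypothetical; `tensors` refers to lazy tensors):
//
//   torch::lazy::Value list_value = GetTensorList(tensors);
//   // `list_value` is a single IR value representing the whole list.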
175 | |
176 | // Section 1: at::Tensor => LazyTensor. |
177 | // Extracts the LazyTensor out of an at::Tensor. Returns a null LazyTensor |
178 | // if the tensor is not a lazy tensor. |
179 | TORCH_API LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor); |
180 | |
181 | // Extracts the LazyTensor out of an at::Tensor. Throws an exception |
182 | // if the tensor is not a lazy tensor. |
183 | TORCH_API LazyTensorPtr GetLtcTensor(const at::Tensor& tensor); |
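
// A typical call pattern (hypothetical; `t` is an at::Tensor):
//
//   LazyTensorPtr maybe_lt = TryGetLtcTensor(t); // null if `t` is not lazy
//   if (!maybe_lt) {
//     // handle the non-lazy case
//   }
//   LazyTensorPtr lt = GetLtcTensor(t); // throws instead of returning null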
184 | |
185 | // Same as above, applied to a list of tensors. |
186 | TORCH_API std::vector<LazyTensorPtr> GetLtcTensors( |
187 | c10::ArrayRef<at::Tensor> tensors); |
188 | |
189 | // If tensor is a lazy tensor type, returns the LazyTensor embedded within it, |
190 | // otherwise creates a new lazy tensor type with tensor as data. |
191 | TORCH_API LazyTensorPtr GetOrCreateLtcTensor( |
192 | const c10::optional<at::Tensor>& tensor, |
193 | const BackendDevice& device); |
194 | |
195 | TORCH_API LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber( |
196 | const at::Tensor& tensor, |
197 | const BackendDevice& device); |
198 | |
// Section 2: LazyTensor => at::Tensor.
// Creates an ATen tensor from a LazyTensor.
201 | TORCH_API at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor); |
202 | TORCH_API at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor); |
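
// A round-trip sketch (hypothetical; `t` is an at::Tensor on the lazy
// device):
//
//   LazyTensorPtr lt = GetLtcTensor(t);
//   at::Tensor back = CreateAtenFromLtcTensor(lt); // re-wraps `lt` as an
//                                                  // at::Tensor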
203 | |
204 | // Note [Lazy Tensor Functionalization] |
// The functionalization pass is implemented by wrapping all TensorImpl
// objects in C++ with an extra FunctionalTensorWrapper object that knows
// how to perform functionalization.
208 | // |
209 | // Certain functions in the aten API serve as entry/exit points for |
210 | // functionalization, where we need to perform the wrapping/unwrapping: |
211 | // - aten::to.device |
212 | // - aten::empty |
213 | |
// Given a non-lazy tensor, this function creates a lazy tensor on the
// specified (lazy) device. The functionalize_output argument determines
// whether or not we should wrap the output in a "functional wrapper".
//
// How do you know whether to pass true/false for functionalize_output?
//
// Case 1: nonlazy -> lazy
// If you're implementing a function that takes in nonlazy tensors and returns
// lazy tensors, then you should think of that function as an "entrypoint" to
// functionalization, and use functionalize_output=true. Examples include:
// - factory functions (the LTC kernel for at::empty)
// - CPU -> lazy device conversions (the LTC kernel for at::to_device)
226 | // |
// Case 2: lazy -> lazy
// If you're implementing a function that takes in lazy tensors and returns
// lazy tensors, **but** requires creating lazy tensors internally, then you
// can assume that the current function is running inside of some outer
// context where functionalization is already running, which will take care
// of doing the wrapping for you; in that case, use
// functionalize_output=false. Examples include:
// - CPU fallback (takes in lazy tensors, converts them to CPU, calls the
//   kernel, and converts the returned tensors back to lazy tensors).
237 | TORCH_API at::Tensor to_lazy_tensor( |
238 | const at::Tensor& self, |
239 | const c10::TensorOptions& options, |
240 | at::Device device, |
241 | bool non_blocking, |
242 | bool functionalize_output); |
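
// For instance, a Case 1 entry point might call (hypothetical values):
//
//   at::Tensor lazy = to_lazy_tensor(
//       cpu_tensor,
//       cpu_tensor.options().device(at::kLazy),
//       at::Device(at::kLazy),
//       /*non_blocking=*/false,
//       /*functionalize_output=*/true);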
243 | |
244 | template <size_t... Indices> |
245 | auto TupleAtenFromLtcTensorsImpl( |
246 | const std::vector<LazyTensorPtr>& tensors, |
247 | std::index_sequence<Indices...>) { |
248 | return std::make_tuple(CreateAtenFromLtcTensor(tensors[Indices])...); |
249 | } |
250 | |
251 | template <size_t N> |
252 | auto TupleAtenFromLtcTensors(const std::vector<LazyTensorPtr>& tensors) { |
253 | return TupleAtenFromLtcTensorsImpl(tensors, std::make_index_sequence<N>{}); |
254 | } |
255 | |
256 | } // namespace lazy |
257 | } // namespace torch |
258 | |