#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/tensor.h>

#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor_impl.h>
#include <torch/csrc/lazy/core/tensor_util.h>

#include <ATen/FunctionalTensorWrapper.h>

namespace torch {
namespace lazy {
namespace {
LazyTensorPtr GetOrCreateLtcTensor(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  if (!tensor.defined()) {
    return torch::lazy::LazyTensorPtr();
  }
  auto lazy_tensor = TryGetLtcTensor(tensor);
  return lazy_tensor ? lazy_tensor : LazyTensor::Create(tensor, device);
}
} // namespace

LazyTensor::Data::~Data() {
  LazyGraphExecutor::Get()->UnregisterTensor(this);
}

LazyTensorPtr LazyTensor::Create(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  TORCH_CHECK(tensor.device().type() != at::kLazy);
  LazyTensorPtr lazy_tensor =
      c10::make_intrusive<LazyTensor>(LazyTensor(tensor, device));
  LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
  return lazy_tensor;
}

LazyTensorPtr LazyTensor::Create(Value ir_value, const BackendDevice& device) {
  LazyTensorPtr lazy_tensor =
      c10::make_intrusive<LazyTensor>(LazyTensor(std::move(ir_value), device));
  LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
  return lazy_tensor;
}

LazyTensorPtr LazyTensor::Create(BackendDataPtr handle) {
  LazyTensorPtr lazy_tensor =
      c10::make_intrusive<LazyTensor>(LazyTensor(std::move(handle)));
  LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
  return lazy_tensor;
}

LazyTensorPtr LazyTensor::Create(std::shared_ptr<Data> data) {
  return c10::make_intrusive<LazyTensor>(LazyTensor(std::move(data)));
}
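
// Illustrative usage sketch (hypothetical call sites; `cpu_tensor`, `device`,
// and `node` are assumed to exist in the caller):
//
//   // Wrap an eager (non-lazy) tensor for tracing on a backend device:
//   LazyTensorPtr lt = LazyTensor::Create(cpu_tensor, device);
//
//   // Wrap one output of an already-traced IR node:
//   LazyTensorPtr lt2 = LazyTensor::Create(Value(node, /*index=*/0), device);
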
60 | |
61 | LazyTensor::LazyTensor(const at::Tensor& tensor, const BackendDevice& device) |
62 | : LazyTensor(std::make_shared<Data>(tensor, device)) {} |
63 | |
64 | LazyTensor::LazyTensor(BackendDataPtr handle) |
65 | : LazyTensor(std::make_shared<Data>(handle, handle->device())) {} |
66 | |
67 | LazyTensor::LazyTensor(Value ir_value, const BackendDevice& device) |
68 | : LazyTensor(std::make_shared<Data>(std::move(ir_value), device)) { |
69 | TryLimitGraphSize(); |
70 | } |
71 | |
72 | LazyTensor::LazyTensor(std::shared_ptr<Data> data) : data_(std::move(data)) {} |
73 | |
74 | auto LazyTensor::data() const -> const std::shared_ptr<Data>& { |
75 | TORCH_CHECK(data_ != nullptr, "Trying to access a null cursor" ); |
76 | return data_; |
77 | } |
78 | |
79 | int64_t LazyTensor::size(int64_t dim) const { |
80 | auto tensor_shape = shape(); |
81 | int rank = tensor_shape.Get().dim(); |
82 | int dim_index = GetCanonicalDimensionIndex(dim, rank); |
83 | return tensor_shape.Get().size(dim_index); |
84 | } |
85 | |
86 | at::ScalarType LazyTensor::dtype() const { |
87 | return shape().Get().scalar_type(); |
88 | } |
89 | |
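// The shape is resolved from the most authoritative source available: a
// materialized backend handle first, then the pending IR value, and finally
// the cached eager tensor data.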
MaybeRef<Shape> LazyTensor::shape() const {
  if (data()->handle != nullptr) {
    return Shape(data()->handle->shape());
  }
  if (data()->ir_value) {
    // TODO(whc) remove shape from LazyTensor API too!
    return data()->ir_value.shape();
  }
  TORCH_CHECK(data()->tensor_data);
  return Shape(
      data()->tensor_data->scalar_type(),
      ToI64Vector(data()->tensor_data->sizes()));
}

const BackendDevice& LazyTensor::GetDevice() const {
  return data()->device;
}

int64_t LazyTensor::GetUniqueId() const {
  return data()->unique_id;
}

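// Returns a materialized backend handle, forcing execution of the pending IR
// graph (or an eager upload of tensor_data) if this tensor has not been
// realized on the device yet.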
BackendDataPtr LazyTensor::GetDataHandle() {
  BackendDataPtr handle = CurrentDataHandle();
  if (handle != nullptr) {
    TORCH_CHECK(
        handle->HasValue(),
        "Trying to access data while an async operation is in flight: ",
        handle->shape().to_string());
    return handle;
  }

  if (data()->ir_value) {
    ApplyPendingGraph();
  } else {
    TORCH_CHECK(data()->tensor_data);
    data()->handle = TensorToDataHandle(*data()->tensor_data, GetDevice());
  }

  return data()->handle;
}

BackendDataPtr LazyTensor::CurrentDataHandle() const {
  return data()->handle;
}

void LazyTensor::SetDataHandle(BackendDataPtr handle) {
  SetDataHandle(std::move(handle), /*sync=*/true);
}

void LazyTensor::SetDataHandle(BackendDataPtr handle, bool sync) {
  data()->handle = std::move(handle);
  // Assigning device data should always clear the IR node, to allow graph
  // trimming.
  AssignIrValue(Value());
  if (sync) {
    data()->tensor_data = c10::nullopt;
  }
}

void LazyTensor::SetIrValue(Value ir_value) {
  data()->handle = nullptr;
  data()->tensor_data = c10::nullopt;
  AssignIrValue(std::move(ir_value));
  TryLimitGraphSize();
}

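// In-place variant: the incoming value may carry a different scalar type than
// this tensor's logical dtype (e.g. after type promotion), so a cast back to
// the original dtype is inserted first.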
void LazyTensor::SetInPlaceIrValue(Value ir_value) {
  auto tensor_shape = shape();
  if (tensor_shape.Get().scalar_type() != ir_value.shape().scalar_type()) {
    ir_value =
        MakeCast(ir_value, tensor_shape.Get().scalar_type(), c10::nullopt);
  }
  SetIrValue(std::move(ir_value));
}

void LazyTensor::AssignIrValue(Value ir_value) const {
  data()->ir_value = std::move(ir_value);
  data()->generation += 1;
}

void LazyTensor::TryLimitGraphSize() {
  if (data()->ir_value &&
      LazyGraphExecutor::Get()->IncTrimCounter() %
              FLAGS_torch_lazy_trim_graph_check_frequency ==
          0) {
    size_t graph_size = Util::GetGraphSize({data()->ir_value.node.get()});
    if (graph_size > FLAGS_torch_lazy_trim_graph_size) {
      TORCH_LAZY_COUNTER("TrimIrGraph", 1);
      ApplyPendingGraph();
    }
  }
}

Value LazyTensor::GetIrValue() const {
  Value ir_value = CurrentIrValue();
  if (ir_value) {
    return ir_value;
  }
  BackendDataPtr handle = CurrentDataHandle();
  if (handle != nullptr) {
    // In the case of a tensor node, we do not clear the data when we set the
    // IR node. This is because we want further calls to GetIrValue() to fetch
    // the same IR node, and not create new ones (even though the lowering
    // context will still collapse them all into a single parameter op). So
    // the call that wants the data will still find it, without having to
    // fetch it from the server via a computation-client call.
    AssignIrValue(CreateTensorNode(handle, /*read_only=*/false));
    return data()->ir_value;
  }
  c10::optional<at::Tensor> tensor_data = CurrentTensorData();
  TORCH_CHECK(tensor_data);
  AssignIrValue(GetIrValueForTensor(*tensor_data, GetDevice()));
  return data()->ir_value;
}
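
// Illustrative sketch of the caching behavior above (hypothetical; `lt` is an
// assumed LazyTensorPtr backed by device data):
//
//   Value a = lt->GetIrValue(); // creates and caches a DeviceData node
//   Value b = lt->GetIrValue(); // returns the same cached node, not a new one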

Value LazyTensor::CurrentIrValue() const {
  return data()->ir_value;
}

void LazyTensor::SetTensorData(at::Tensor tensor_data) {
  data()->tensor_data = std::move(tensor_data);
}

c10::optional<at::Tensor> LazyTensor::CurrentTensorData() const {
  return data()->tensor_data;
}

Value LazyTensor::GetIrValueForTensor(
    const at::Tensor& tensor,
    const BackendDevice& device) const {
  BackendDataPtr data;
  bool read_only = false;
  if (tensor.dim() == 0 && tensor.numel() == 1) {
    at::Scalar value = tensor.item();
    if (IsSpecialScalar(value)) {
      return MakeScalar(value, tensor.scalar_type());
    }
    data = LazyGraphExecutor::Get()->GetDeviceData(tensor.cpu(), device);
    read_only = true;
  } else {
    TORCH_LAZY_TIMED("IrValueTensorToDataHandle");
    data = TensorToDataHandle(tensor, device);
  }
  return CreateTensorNode(std::move(data), read_only);
}

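// Materializes this lazy tensor as an eager at::Tensor. With detached=true,
// the result is decoupled from this LazyTensor (by handing off our cached
// tensor_data when another authoritative source exists, or by copying); with
// detached=false, the materialized tensor is cached in tensor_data.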
at::Tensor LazyTensor::ToTensor(bool detached) {
  at::Tensor tensor;
  c10::optional<at::Tensor> tensor_data = CurrentTensorData();
  if (!tensor_data) {
    LazyGraphExecutor::Get()->DeviceBarrier(GetDevice());
    // The GetDataHandle() call will trigger an ApplyPendingGraph() if an IR
    // node is available on the tensor.
    std::vector<at::Tensor> tensors =
        DataHandlesToTensors({GetDataHandle()}, dtype());
    tensor = std::move(tensors.front());
    if (!detached) {
      SetTensorData(tensor);
    }
  } else {
    tensor = *tensor_data;
    if (detached) {
      if (data()->ir_value || data()->handle != nullptr) {
        // If we have other authoritative sources, just drop our reference and
        // transfer it to the caller.
        data()->tensor_data = c10::nullopt;
      } else {
        // Otherwise we need to make a copy to prevent the caller from
        // changing our version.
        tensor = CopyTensor(tensor);
      }
    }
  }
  return tensor;
}

void LazyTensor::ShallowCopyTo(LazyTensorPtr dest) const {
  dest->SetIrValue(GetIrValue());
}

void LazyTensor::SetTensor(at::Tensor tensor) {
  SetTensorData(tensor);
  data()->handle = nullptr;
  AssignIrValue(Value());
}

void LazyTensor::UpdateFromTensor(at::Tensor tensor, bool sync) {
  if (sync) {
    at::Tensor typed_tensor = CopyTensor(tensor, dtype(), /*copy=*/false);
    SetIrValue(GetIrValueForTensor(typed_tensor, GetDevice()));
  } else {
    SetTensorData(tensor);
    data()->handle = nullptr;
    AssignIrValue(Value());
  }
}

void LazyTensor::UpdateFromTensorOut(at::Tensor tensor) {
  UpdateFromTensor(std::move(tensor), /*sync=*/false);
}

void LazyTensor::UpdateFromTensorOut(const LazyTensorPtr& tensor) {
  SetIrValue(tensor->GetIrValue());
}

Value LazyTensor::CreateTensorNode(BackendDataPtr data, bool read_only) const {
  data->SetInfo(std::make_shared<LazyGraphExecutor::DeviceDataInfo>(
      GetUniqueId(), read_only));
  return MakeDeviceData(std::move(data));
}

std::vector<LazyTensorPtr> LazyTensor::MakeOutputTensors(NodePtr node) const {
  std::vector<LazyTensorPtr> tensors;
  tensors.reserve(node->num_outputs());
  for (const auto i : c10::irange(node->num_outputs())) {
    tensors.push_back(Create(Value(node, i), GetDevice()));
  }
  return tensors;
}
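
// Illustrative sketch (hypothetical two-output `node` and tensor `lt`):
//
//   auto outs = lt->MakeOutputTensors(node);
//   // outs[0] wraps Value(node, 0) and outs[1] wraps Value(node, 1), both
//   // placed on lt's device.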

LazyTensorPtr LazyTensor::CopyTensorToDevice(const BackendDevice& device) {
  // TODO: This can be optimized.
  return Create(ToTensor(/*detached=*/true), device);
}

void LazyTensor::ApplyPendingGraph() {
  LazyGraphExecutor::Get()->DeviceBarrier(GetDevice());
  // This method is called to ensure that the tensor data is available on
  // device, so that a call to CurrentDataHandle() returns a valid pointer.
  if (CurrentDataHandle() == nullptr) {
    std::vector<LazyTensorPtr> tensors(
        {c10::make_intrusive<LazyTensor>(LazyTensor(*this))});
    LazyGraphExecutor::Get()->SyncTensorsGraph(
        &tensors,
        {},
        /*wait=*/true,
        /*sync_ltc_data=*/false);
  }
}

int64_t LazyTensor::GetNextTensorId() {
  static std::atomic<int64_t>* id_generator = new std::atomic<int64_t>(1);
  return id_generator->fetch_add(1);
}

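// Builds a single TensorList IR value out of lazy tensors, so that ops taking
// an at::TensorList argument can consume one IR input instead of N.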
torch::lazy::Value GetTensorList(at::ITensorListRef tensors) {
  std::vector<Value> values;
  for (const auto& t : tensors) {
    auto* impl = dynamic_cast<LTCTensorImpl*>(t.unsafeGetTensorImpl());
    TORCH_INTERNAL_ASSERT(
        impl,
        "GetTensorList only supports lists of valid tensors, but optional support could be added");
    values.push_back(impl->tensor()->GetIrValue());
  }

  return torch::lazy::Value(torch::lazy::MakeTensorList(std::move(values)));
}

LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor) {
  auto* impl = dynamic_cast<LTCTensorImpl*>(
      maybe_unwrap_functional(tensor).unsafeGetTensorImpl());
  if (impl == nullptr) {
    return LazyTensorPtr();
  }
  return impl->tensor();
}

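// Throwing counterpart of TryGetLtcTensor: fails when the input is not backed
// by an LTCTensorImpl.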
LazyTensorPtr GetLtcTensor(const at::Tensor& tensor) {
  auto lazy_tensor = TryGetLtcTensor(tensor);
  TORCH_CHECK(
      lazy_tensor, "Input tensor is not a lazy tensor: ", tensor.toString());
  return lazy_tensor;
}

std::vector<LazyTensorPtr> GetLtcTensors(c10::ArrayRef<at::Tensor> tensors) {
  std::vector<LazyTensorPtr> ltc_tensors;
  ltc_tensors.reserve(tensors.size());
  for (const auto& tensor : tensors) {
    ltc_tensors.emplace_back(TryGetLtcTensor(tensor));
  }
  return ltc_tensors;
}

LazyTensorPtr GetOrCreateLtcTensor(
    const c10::optional<at::Tensor>& tensor,
    const BackendDevice& device) {
  return GetOrCreateLtcTensor(tensor.value_or(at::Tensor()), device);
}

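// Wrapped numbers (scalars auto-boxed into 0-dim tensors by the dispatcher)
// live on an eager device, so they are converted here rather than looked up.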
LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  // TODO: There are places in core where a scalar is wrapped but not marked as
  // wrapped.
  return (tensor.unsafeGetTensorImpl()->is_wrapped_number() ||
          (tensor.dim() == 0 && tensor.numel() == 1))
      ? GetOrCreateLtcTensor(tensor, device)
      : GetLtcTensor(tensor);
}

at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor) {
  return ltc_tensor ? at::Tensor(c10::make_intrusive<LTCTensorImpl>(ltc_tensor))
                    : at::Tensor();
}

at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor) {
  return at::Tensor(c10::make_intrusive<LTCTensorImpl>(std::move(ltc_tensor)));
}

at::Tensor to_lazy_tensor(
    const at::Tensor& self,
    const c10::TensorOptions& options,
    at::Device device,
    bool non_blocking,
    bool functionalize_output) {
  TORCH_INTERNAL_ASSERT(self.device().type() != c10::kLazy);
  TORCH_INTERNAL_ASSERT(device.type() == c10::kLazy);

  auto eager_tensor =
      self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
  auto lazy_self = torch::lazy::GetOrCreateLtcTensor(
      eager_tensor, torch::lazy::atenDeviceToBackendDevice(device));
  auto out = torch::lazy::CreateAtenFromLtcTensor(lazy_self);
  if (functionalize_output) {
    // See Note [Lazy Tensor Functionalization]
    return at::functionalization::impl::to_functional_tensor(out);
  } else {
    return out;
  }
}
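
// Illustrative call (hypothetical; `eager` is an assumed CPU tensor):
//
//   at::Tensor lazy = to_lazy_tensor(
//       eager,
//       eager.options(),
//       at::Device(at::kLazy, 0),
//       /*non_blocking=*/false,
//       /*functionalize_output=*/true);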

} // namespace lazy
} // namespace torch