#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/tensor.h>

#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor_impl.h>
#include <torch/csrc/lazy/core/tensor_util.h>

#include <ATen/FunctionalTensorWrapper.h>

namespace torch {
namespace lazy {
namespace {
LazyTensorPtr GetOrCreateLtcTensor(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  if (!tensor.defined()) {
    return torch::lazy::LazyTensorPtr();
  }
  auto lazy_tensor = TryGetLtcTensor(tensor);
  return lazy_tensor ? lazy_tensor : LazyTensor::Create(tensor, device);
}
} // namespace

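// When the underlying Data block is destroyed, drop it from the graph
// executor's live-tensor registry so it is no longer tracked for syncing.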
LazyTensor::Data::~Data() {
  LazyGraphExecutor::Get()->UnregisterTensor(this);
}

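// The Create() overloads build a LazyTensor from an eager tensor, an IR value,
// a backend data handle, or an existing Data block. All but the last also
// register the new tensor with the LazyGraphExecutor.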
LazyTensorPtr LazyTensor::Create(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  TORCH_CHECK(tensor.device().type() != at::kLazy);
  LazyTensorPtr lazy_tensor =
      c10::make_intrusive<LazyTensor>(LazyTensor(tensor, device));
  LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
  return lazy_tensor;
}

LazyTensorPtr LazyTensor::Create(Value ir_value, const BackendDevice& device) {
  LazyTensorPtr lazy_tensor =
      c10::make_intrusive<LazyTensor>(LazyTensor(std::move(ir_value), device));
  LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
  return lazy_tensor;
}

LazyTensorPtr LazyTensor::Create(BackendDataPtr handle) {
  LazyTensorPtr lazy_tensor =
      c10::make_intrusive<LazyTensor>(LazyTensor(std::move(handle)));
  LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
  return lazy_tensor;
}

LazyTensorPtr LazyTensor::Create(std::shared_ptr<Data> data) {
  return c10::make_intrusive<LazyTensor>(LazyTensor(std::move(data)));
}

LazyTensor::LazyTensor(const at::Tensor& tensor, const BackendDevice& device)
    : LazyTensor(std::make_shared<Data>(tensor, device)) {}

LazyTensor::LazyTensor(BackendDataPtr handle)
    : LazyTensor(std::make_shared<Data>(handle, handle->device())) {}

LazyTensor::LazyTensor(Value ir_value, const BackendDevice& device)
    : LazyTensor(std::make_shared<Data>(std::move(ir_value), device)) {
  TryLimitGraphSize();
}

LazyTensor::LazyTensor(std::shared_ptr<Data> data) : data_(std::move(data)) {}

auto LazyTensor::data() const -> const std::shared_ptr<Data>& {
  TORCH_CHECK(data_ != nullptr, "Trying to access a null cursor");
  return data_;
}

int64_t LazyTensor::size(int64_t dim) const {
  auto tensor_shape = shape();
  int rank = tensor_shape.Get().dim();
  int dim_index = GetCanonicalDimensionIndex(dim, rank);
  return tensor_shape.Get().size(dim_index);
}

at::ScalarType LazyTensor::dtype() const {
  return shape().Get().scalar_type();
}

MaybeRef<Shape> LazyTensor::shape() const {
  if (data()->handle != nullptr) {
    return Shape(data()->handle->shape());
  }
  if (data()->ir_value) {
    // TODO(whc) remove shape from LazyTensor API too!
    return data()->ir_value.shape();
  }
  TORCH_CHECK(data()->tensor_data);
  return Shape(
      data()->tensor_data->scalar_type(),
      ToI64Vector(data()->tensor_data->sizes()));
}

const BackendDevice& LazyTensor::GetDevice() const {
  return data()->device;
}

int64_t LazyTensor::GetUniqueId() const {
  return data()->unique_id;
}

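// Returns backend data for this tensor, materializing it if necessary: an
// existing handle is returned as-is (it must already hold a value), otherwise
// the pending IR graph is executed or the cached eager tensor is copied to
// the device.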
BackendDataPtr LazyTensor::GetDataHandle() {
  BackendDataPtr handle = CurrentDataHandle();
  if (handle != nullptr) {
    TORCH_CHECK(
        handle->HasValue(),
        "Trying to access data while an async operation is in flight: ",
        handle->shape().to_string());
    return handle;
  }

  if (data()->ir_value) {
    ApplyPendingGraph();
  } else {
    TORCH_CHECK(data()->tensor_data);
    data()->handle = TensorToDataHandle(*data()->tensor_data, GetDevice());
  }

  return data()->handle;
}

BackendDataPtr LazyTensor::CurrentDataHandle() const {
  return data()->handle;
}

void LazyTensor::SetDataHandle(BackendDataPtr handle) {
  SetDataHandle(std::move(handle), /*sync=*/true);
}

void LazyTensor::SetDataHandle(BackendDataPtr handle, bool sync) {
  data()->handle = std::move(handle);
  // Assigning device data should always clear the IR node, to allow graph
  // trimming.
  AssignIrValue(Value());
  if (sync) {
    data()->tensor_data = c10::nullopt;
  }
}

void LazyTensor::SetIrValue(Value ir_value) {
  data()->handle = nullptr;
  data()->tensor_data = c10::nullopt;
  AssignIrValue(std::move(ir_value));
  TryLimitGraphSize();
}

void LazyTensor::SetInPlaceIrValue(Value ir_value) {
  auto tensor_shape = shape();
  if (tensor_shape.Get().scalar_type() != ir_value.shape().scalar_type()) {
    ir_value =
        MakeCast(ir_value, tensor_shape.Get().scalar_type(), c10::nullopt);
  }
  SetIrValue(std::move(ir_value));
}

void LazyTensor::AssignIrValue(Value ir_value) const {
  data()->ir_value = std::move(ir_value);
  data()->generation += 1;
}

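// Every FLAGS_torch_lazy_trim_graph_check_frequency calls, measure the pending
// IR graph rooted at this tensor and force execution once it exceeds
// FLAGS_torch_lazy_trim_graph_size, bounding the size of the traced graph.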
void LazyTensor::TryLimitGraphSize() {
  if (data()->ir_value &&
      LazyGraphExecutor::Get()->IncTrimCounter() %
              FLAGS_torch_lazy_trim_graph_check_frequency ==
          0) {
    size_t graph_size = Util::GetGraphSize({data()->ir_value.node.get()});
    if (graph_size > FLAGS_torch_lazy_trim_graph_size) {
      TORCH_LAZY_COUNTER("TrimIrGraph", 1);
      ApplyPendingGraph();
    }
  }
}

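// Returns the IR value for this tensor, creating one on demand from the
// backend data handle or from the cached eager tensor when none is present.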
Value LazyTensor::GetIrValue() const {
  Value ir_value = CurrentIrValue();
  if (ir_value) {
    return ir_value;
  }
  BackendDataPtr handle = CurrentDataHandle();
  if (handle != nullptr) {
    // In the case of a tensor node, we do not clear the data when we set the
    // IR node. This is because we want further calls to GetIrValue() to fetch
    // the same IR node, and not create new ones (even though the lowering
    // context will still collapse them all into a single parameter op). So the
    // call which wants the data will still find it, without having to fetch it
    // via a computation client from-server call.
    AssignIrValue(CreateTensorNode(handle, /*read_only=*/false));
    return data()->ir_value;
  }
  c10::optional<at::Tensor> tensor_data = CurrentTensorData();
  TORCH_CHECK(tensor_data);
  AssignIrValue(GetIrValueForTensor(*tensor_data, GetDevice()));
  return data()->ir_value;
}

Value LazyTensor::CurrentIrValue() const {
  return data()->ir_value;
}

void LazyTensor::SetTensorData(at::Tensor tensor_data) {
  data()->tensor_data = std::move(tensor_data);
}

c10::optional<at::Tensor> LazyTensor::CurrentTensorData() const {
  return data()->tensor_data;
}

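// Lowers an eager tensor into an IR value: special scalars become scalar
// nodes, other zero-dimensional tensors go through the executor's device-data
// cache as read-only data, and everything else is copied to the device
// directly.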
Value LazyTensor::GetIrValueForTensor(
    const at::Tensor& tensor,
    const BackendDevice& device) const {
  BackendDataPtr data;
  bool read_only = false;
  if (tensor.dim() == 0 && tensor.numel() == 1) {
    at::Scalar value = tensor.item();
    if (IsSpecialScalar(value)) {
      return MakeScalar(value, tensor.scalar_type());
    }
    data = LazyGraphExecutor::Get()->GetDeviceData(tensor.cpu(), device);
    read_only = true;
  } else {
    TORCH_LAZY_TIMED("IrValueTensorToDataHandle");
    data = TensorToDataHandle(tensor, device);
  }
  return CreateTensorNode(std::move(data), read_only);
}

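// Materializes this lazy tensor as an eager at::Tensor. If no cached tensor
// data exists, the device data is fetched (executing any pending graph); when
// `detached` is false the result is also cached back on the LazyTensor.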
at::Tensor LazyTensor::ToTensor(bool detached) {
  at::Tensor tensor;
  c10::optional<at::Tensor> tensor_data = CurrentTensorData();
  if (!tensor_data) {
    LazyGraphExecutor::Get()->DeviceBarrier(GetDevice());
    // The GetDataHandle() call will trigger an ApplyPendingGraph() if an IR
    // Node is available on the tensor.
    std::vector<at::Tensor> tensors =
        DataHandlesToTensors({GetDataHandle()}, dtype());
    tensor = std::move(tensors.front());
    if (!detached) {
      SetTensorData(tensor);
    }
  } else {
    tensor = *tensor_data;
    if (detached) {
      if (data()->ir_value || data()->handle != nullptr) {
        // If we have other authoritative sources, just drop our reference and
        // transfer it to the caller.
        data()->tensor_data = c10::nullopt;
      } else {
        // Otherwise we need to make a copy to prevent the caller from changing
        // our version.
        tensor = CopyTensor(tensor);
      }
    }
  }
  return tensor;
}

void LazyTensor::ShallowCopyTo(LazyTensorPtr dest) const {
  dest->SetIrValue(GetIrValue());
}

void LazyTensor::SetTensor(at::Tensor tensor) {
  SetTensorData(tensor);
  data()->handle = nullptr;
  AssignIrValue(Value());
}

void LazyTensor::UpdateFromTensor(at::Tensor tensor, bool sync) {
  if (sync) {
    at::Tensor typed_tensor = CopyTensor(tensor, dtype(), /*copy=*/false);
    SetIrValue(GetIrValueForTensor(typed_tensor, GetDevice()));
  } else {
    SetTensorData(tensor);
    data()->handle = nullptr;
    AssignIrValue(Value());
  }
}

void LazyTensor::UpdateFromTensorOut(at::Tensor tensor) {
  UpdateFromTensor(std::move(tensor), /*sync=*/false);
}

void LazyTensor::UpdateFromTensorOut(const LazyTensorPtr& tensor) {
  SetIrValue(tensor->GetIrValue());
}

Value LazyTensor::CreateTensorNode(BackendDataPtr data, bool read_only) const {
  data->SetInfo(std::make_shared<LazyGraphExecutor::DeviceDataInfo>(
      GetUniqueId(), read_only));
  return MakeDeviceData(std::move(data));
}

std::vector<LazyTensorPtr> LazyTensor::MakeOutputTensors(NodePtr node) const {
  std::vector<LazyTensorPtr> tensors;
  tensors.reserve(node->num_outputs());
  for (const auto i : c10::irange(node->num_outputs())) {
    tensors.push_back(Create(Value(node, i), GetDevice()));
  }
  return tensors;
}

LazyTensorPtr LazyTensor::CopyTensorToDevice(const BackendDevice& device) {
  // TODO: This can be optimized.
  return Create(ToTensor(/*detached=*/true), device);
}

void LazyTensor::ApplyPendingGraph() {
  LazyGraphExecutor::Get()->DeviceBarrier(GetDevice());
  // This method is called to ensure that the tensor data is available on
  // device, so that a call to CurrentDataHandle() returns a valid pointer.
  if (CurrentDataHandle() == nullptr) {
    std::vector<LazyTensorPtr> tensors(
        {c10::make_intrusive<LazyTensor>(LazyTensor(*this))});
    LazyGraphExecutor::Get()->SyncTensorsGraph(
        &tensors,
        {},
        /*wait=*/true,
        /*sync_ltc_data=*/false);
  }
}

int64_t LazyTensor::GetNextTensorId() {
  static std::atomic<int64_t>* id_generator = new std::atomic<int64_t>(1);
  return id_generator->fetch_add(1);
}

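// Builds a TensorList IR value from a list of tensors; every entry must be a
// lazy tensor backed by an LTCTensorImpl.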
torch::lazy::Value GetTensorList(at::ITensorListRef tensors) {
  std::vector<Value> values;
  for (const auto& t : tensors) {
    auto* impl = dynamic_cast<LTCTensorImpl*>(t.unsafeGetTensorImpl());
    TORCH_INTERNAL_ASSERT(
        impl,
        "GetTensorList only supports lists of valid tensors, but optional support could be added");
    values.push_back(impl->tensor()->GetIrValue());
  }

  return torch::lazy::Value(torch::lazy::MakeTensorList(std::move(values)));
}

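// Returns the LazyTensor backing `tensor` (unwrapping a functional wrapper
// first), or a null LazyTensorPtr if the tensor is not a lazy tensor.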
LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor) {
  auto* impl = dynamic_cast<LTCTensorImpl*>(
      maybe_unwrap_functional(tensor).unsafeGetTensorImpl());
  if (impl == nullptr) {
    // return c10::make_intrusive<LazyTensor>();
    return LazyTensorPtr();
  }
  return impl->tensor();
}

LazyTensorPtr GetLtcTensor(const at::Tensor& tensor) {
  auto lazy_tensor = TryGetLtcTensor(tensor);
  TORCH_CHECK(
      lazy_tensor, "Input tensor is not a lazy tensor: ", tensor.toString());
  return lazy_tensor;
}

std::vector<LazyTensorPtr> GetLtcTensors(c10::ArrayRef<at::Tensor> tensors) {
  std::vector<LazyTensorPtr> ltc_tensors;
  ltc_tensors.reserve(tensors.size());
  for (const auto& tensor : tensors) {
    ltc_tensors.emplace_back(TryGetLtcTensor(tensor));
  }
  return ltc_tensors;
}

LazyTensorPtr GetOrCreateLtcTensor(
    const c10::optional<at::Tensor>& tensor,
    const BackendDevice& device) {
  return GetOrCreateLtcTensor(tensor.value_or(at::Tensor()), device);
}

LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  // TODO: There are places in core where a scalar is wrapped but not marked as
  // wrapped.
  return (tensor.unsafeGetTensorImpl()->is_wrapped_number() ||
          (tensor.dim() == 0 && tensor.numel() == 1))
      ? GetOrCreateLtcTensor(tensor, device)
      : GetLtcTensor(tensor);
}

at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor) {
  return ltc_tensor ? at::Tensor(c10::make_intrusive<LTCTensorImpl>(ltc_tensor))
                    : at::Tensor();
}

at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor) {
  return at::Tensor(c10::make_intrusive<LTCTensorImpl>(std::move(ltc_tensor)));
}

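// Copies an eager tensor to the lazy device, optionally wrapping the result
// in a functional tensor (see Note [Lazy Tensor Functionalization]).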
at::Tensor to_lazy_tensor(
    const at::Tensor& self,
    const c10::TensorOptions& options,
    at::Device device,
    bool non_blocking,
    bool functionalize_output) {
  TORCH_INTERNAL_ASSERT(self.device().type() != c10::kLazy);
  TORCH_INTERNAL_ASSERT(device.type() == c10::kLazy);

  auto eager_tensor =
      self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
  auto lazy_self = torch::lazy::GetOrCreateLtcTensor(
      eager_tensor, torch::lazy::atenDeviceToBackendDevice(device));
  auto out = torch::lazy::CreateAtenFromLtcTensor(lazy_self);
  if (functionalize_output) {
    // See Note [Lazy Tensor Functionalization]
    return at::functionalization::impl::to_functional_tensor(out);
  } else {
    return out;
  }
}

} // namespace lazy
} // namespace torch