#include <torch/csrc/lazy/core/tensor_impl.h>

#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/tensor_util.h>

namespace torch {
namespace lazy {
namespace {

// LTCGuardImpl is used by CompositeExplicitAutograd ops or eager fallbacks to
// ensure that certain tensors within the lifetime of the guard are on the same
// device. For example, in RegisterCompositeExplicitAutograd.cpp, outputs of
// each op are checked to be on the same device as the supplied TensorOptions.
// For more information, see DeviceGuard.h. For ops that have LTC native
// function implementations, this guard is omitted.
thread_local c10::Device g_device(c10::DeviceType::Lazy);

struct LTCGuardImpl : public c10::impl::DeviceGuardImplInterface {
  at::DeviceType type() const override {
    return at::DeviceType::Lazy;
  }

  c10::Device exchangeDevice(c10::Device device) const override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    auto old_device = g_device;
    g_device = device;
    return old_device;
  }

  c10::Device getDevice() const override {
    return g_device;
  }

  void setDevice(c10::Device device) const override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    g_device = device;
  }

  void uncheckedSetDevice(c10::Device device) const noexcept override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    g_device = device;
  }

  c10::Stream getStream(c10::Device device) const noexcept override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    return c10::Stream(c10::Stream::DEFAULT, device);
  }

  c10::Stream exchangeStream(c10::Stream _unused) const noexcept override {
    return c10::Stream(c10::Stream::DEFAULT, g_device);
  }

  c10::DeviceIndex deviceCount() const noexcept override {
    // This will get called when autograd initializes its device pool,
    // regardless of whether we have a backend registered beforehand.
    if (!hasBackend()) {
      return 0;
    }

    return getBackend()->GetBackendDevices().size();
  }
};

C10_REGISTER_GUARD_IMPL(Lazy, LTCGuardImpl);
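// Illustrative sketch (comment only, not part of this file's logic): for some
// at::Tensor `lazy_tensor` on the lazy device, a CompositeExplicitAutograd
// kernel or eager fallback effectively does something like
//   c10::DeviceGuard guard(lazy_tensor.device());
// which dispatches to the LTCGuardImpl registered above, so the op's outputs
// can be checked against the same lazy device.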

} // namespace

// TODO(whc) when do we want to clone vs share?
LTCTensorImpl::LTCTensorImpl(const LazyTensorPtr& tensor)
    : LTCTensorImpl(LazyTensor(*tensor)) {}

LTCTensorImpl::LTCTensorImpl(const LazyTensor& tensor)
    : LTCTensorImpl(LazyTensor(tensor)) {}

LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor)
    : c10::TensorImpl(
          c10::DispatchKeySet{
              c10::DispatchKey::Lazy,
              c10::DispatchKey::AutogradLazy},
          c10::scalarTypeToTypeMeta(tensor.dtype()),
          backendDeviceToAtenDevice(tensor.GetDevice())),
      tensor_(c10::make_intrusive<LazyTensor>(std::move(tensor))) {
  set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
}
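// Note: SizesStridesPolicy::CustomSizes (set in the constructor above) makes
// TensorImpl route sizes()/strides()/numel()/dim() through the *_custom
// overrides defined below, so that this metadata can be refreshed from the
// underlying LazyTensor on demand.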

void LTCTensorImpl::set_tensor(const LazyTensorPtr& lazy_tensor) {
  tensor_ = c10::make_intrusive<LazyTensor>(*lazy_tensor);
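  // Reset the cached-metadata generation so size/stride info is re-synced
  // from the new tensor on the next query (see setup_size_properties()).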
  generation_ = 0;
}

c10::intrusive_ptr<c10::TensorImpl> LTCTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<LTCTensorImpl>(tensor_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/version_counter,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

c10::intrusive_ptr<c10::TensorImpl> LTCTensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<LTCTensorImpl>(tensor_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::move(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}
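// Note: the shallow_copy_and_detach overloads above are what autograd uses to
// produce detached views of this impl (e.g. Tensor::detach()); they wrap the
// same underlying lazy tensor data rather than copying any device data.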

void LTCTensorImpl::shallow_copy_from(
    const c10::intrusive_ptr<TensorImpl>& impl) {
  LTCTensorImpl* ltc_impl = dynamic_cast<LTCTensorImpl*>(impl.get());
  TORCH_INTERNAL_ASSERT(ltc_impl);
  copy_tensor_metadata(
      /*src_impl=*/ltc_impl,
      /*dest_impl=*/this,
      /*version_counter=*/version_counter(),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  ltc_impl->tensor_->ShallowCopyTo(tensor_);
  generation_ = 0;
}
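// Note: shallow_copy_from (above) copies both the TensorImpl metadata and,
// via ShallowCopyTo, the source's lazy tensor state into this impl; resetting
// generation_ then forces the size/stride cache to be rebuilt on next access.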

c10::SymIntArrayRef LTCTensorImpl::sym_strides_custom() const {
  return c10::fromIntArrayRefKnownNonNegative(strides_custom());
}

c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
  return c10::fromIntArrayRefKnownNonNegative(sizes_custom());
}

c10::SymInt LTCTensorImpl::sym_numel_custom() const {
  return numel_custom();
}
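// The Sym* overrides above simply wrap the concrete sizes/strides/numel as
// SymInts, since this impl carries concrete (non-symbolic) shapes here.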

void LTCTensorImpl::setup_size_properties() {
  size_t generation = tensor_->generation();
  if (generation != generation_) {
    // Fill in the basic dimension data members that the base class
    // implementation uses in its APIs.
    auto shape = tensor_->shape();
    // We can't call refresh_numel() given that we override sizes() too.
    numel_ = shape.Get().numel();
    sizes_and_strides_.set_sizes(shape.Get().sizes());
    // We can't call empty_tensor_restride(c10::MemoryFormat::Contiguous) given
    // that we override sizes() too.
    auto updated_strides = ComputeArrayStrides(shape.Get().sizes());
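    // For example, sizes {2, 3, 4} produce contiguous (row-major) strides
    // {12, 4, 1}.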
    for (const auto i : c10::irange(updated_strides.size())) {
      sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i];
    }
    generation_ = generation;
  }
}

at::IntArrayRef LTCTensorImpl::sizes_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return sizes_default();
}

at::IntArrayRef LTCTensorImpl::strides_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return strides_default();
}

int64_t LTCTensorImpl::dim_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return dim_default();
}

int64_t LTCTensorImpl::numel_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return numel_default();
}

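// Lazy tensors do not expose a real storage, so the storage offset is always
// reported as zero.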
int64_t LTCTensorImpl::storage_offset_custom() const {
  return 0;
}

bool LTCTensorImpl::is_strides_like_custom(
    c10::MemoryFormat memory_format) const {
  TORCH_INTERNAL_ASSERT(memory_format != at::MemoryFormat::Contiguous);
  return false;
}

bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const {
  // This should be true, but it returns false as a temporary fix for a PyTorch
  // core issue, according to https://github.com/pytorch/xla/pull/2682.
  return false;
}

bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const {
  // TODO(ezyang): I don't think this branch is actually necessary
  // TODO(ezyang): I don't think this logic is right, shouldn't we pass on
  // the memory format?
  if (tensor_->CurrentTensorData()) {
    return tensor_->CurrentTensorData()->is_contiguous();
  }
  // Only check that the storage is already contiguous.
  CHECK(is_contiguous_) << "Non-contiguous storage for lazy tensor";
  // TODO: I don't think this logic is right; we should check the requested
  // memory format before returning true.
  return true;
}

} // namespace lazy
} // namespace torch