#include <torch/csrc/lazy/core/tensor_impl.h>

#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/tensor_util.h>

namespace torch {
namespace lazy {
namespace {

// LTCGuardImpl is used by CompositeExplicitAutograd ops and eager fallbacks to
// ensure that particular tensors within the lifetime of the guard are on the
// same device. For example, in RegisterCompositeExplicitAutograd.cpp, the
// outputs of each op are checked to be on the same device as the supplied
// TensorOptions. For more information, see DeviceGuard.h. For ops that have
// LTC native function implementations, this guard is omitted.
thread_local c10::Device g_device(c10::DeviceType::Lazy);
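
// A minimal sketch of how a call site exercises this guard (hypothetical
// usage for illustration; the real call sites live in DeviceGuard.h and the
// generated CompositeExplicitAutograd wrappers):
//
//   c10::Device lazy_device(c10::DeviceType::Lazy, 0);
//   c10::DeviceGuard guard(lazy_device); // routes to LTCGuardImpl below
//   // within the guard's scope, g_device == lazy_device on this thread
//   // on destruction, the guard restores the previous thread-local device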

struct LTCGuardImpl : public c10::impl::DeviceGuardImplInterface {
  at::DeviceType type() const override {
    return at::DeviceType::Lazy;
  }

  c10::Device exchangeDevice(c10::Device device) const override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    auto old_device = g_device;
    g_device = device;
    return old_device;
  }

  c10::Device getDevice() const override {
    return g_device;
  }

  void setDevice(c10::Device device) const override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    g_device = device;
  }

  void uncheckedSetDevice(c10::Device device) const noexcept override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    g_device = device;
  }

  c10::Stream getStream(c10::Device device) const noexcept override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    return c10::Stream(c10::Stream::DEFAULT, device);
  }

  c10::Stream exchangeStream(c10::Stream _unused) const noexcept override {
    return c10::Stream(c10::Stream::DEFAULT, g_device);
  }

  c10::DeviceIndex deviceCount() const noexcept override {
    // This gets called when autograd initializes its device pool, regardless
    // of whether we have a backend registered beforehand.
    if (!hasBackend()) {
      return 0;
    }

    return static_cast<c10::DeviceIndex>(
        getBackend()->GetBackendDevices().size());
  }
};
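
// Registering the guard implementation makes DeviceType::Lazy resolvable by
// c10::DeviceGuard and the rest of the c10 device-guard machinery.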
C10_REGISTER_GUARD_IMPL(Lazy, LTCGuardImpl);

} // namespace

// TODO(whc) when do we want to clone vs share?
LTCTensorImpl::LTCTensorImpl(const LazyTensorPtr& tensor)
    : LTCTensorImpl(LazyTensor(*tensor)) {}

LTCTensorImpl::LTCTensorImpl(const LazyTensor& tensor)
    : LTCTensorImpl(LazyTensor(tensor)) {}

LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor)
    : c10::TensorImpl(
          c10::DispatchKeySet{
              c10::DispatchKey::Lazy,
              c10::DispatchKey::AutogradLazy},
          c10::scalarTypeToTypeMeta(tensor.dtype()),
          backendDeviceToAtenDevice(tensor.GetDevice())),
      tensor_(c10::make_intrusive<LazyTensor>(std::move(tensor))) {
  set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
}
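
// Replaces the wrapped lazy tensor and resets generation_ so that the cached
// size/stride metadata is recomputed on the next query (see
// setup_size_properties() below).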
void LTCTensorImpl::set_tensor(const LazyTensorPtr& lazy_tensor) {
  tensor_ = c10::make_intrusive<LazyTensor>(*lazy_tensor);
  generation_ = 0;
}
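
// The two shallow_copy_and_detach overloads differ only in how the version
// counter is forwarded; each builds a detached impl around a copy of the
// LazyTensor handle (the copy shares the handle's underlying data).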
c10::intrusive_ptr<c10::TensorImpl> LTCTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<LTCTensorImpl>(tensor_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/version_counter,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

c10::intrusive_ptr<c10::TensorImpl> LTCTensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<LTCTensorImpl>(tensor_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::move(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

void LTCTensorImpl::shallow_copy_from(
    const c10::intrusive_ptr<TensorImpl>& impl) {
  LTCTensorImpl* ltc_impl = dynamic_cast<LTCTensorImpl*>(impl.get());
  TORCH_INTERNAL_ASSERT(ltc_impl);
  copy_tensor_metadata(
      /*src_impl=*/ltc_impl,
      /*dest_impl=*/this,
      /*version_counter=*/version_counter(),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  ltc_impl->tensor_->ShallowCopyTo(tensor_);
  generation_ = 0;
}
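
// Lazy tensor shapes are concrete integers by the time they reach this impl,
// so the SymInt accessors below simply wrap the plain integer results.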
c10::SymIntArrayRef LTCTensorImpl::sym_strides_custom() const {
  return c10::fromIntArrayRefKnownNonNegative(strides_custom());
}

c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
  return c10::fromIntArrayRefKnownNonNegative(sizes_custom());
}

c10::SymInt LTCTensorImpl::sym_numel_custom() const {
  return numel_custom();
}
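
// tensor_->generation() advances when the underlying lazy tensor changes, so
// comparing it against our last-seen generation_ lets us recompute the cached
// sizes/strides only when they may be stale.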
void LTCTensorImpl::setup_size_properties() {
  size_t generation = tensor_->generation();
  if (generation != generation_) {
    // Fill up the basic dimension data members which the base class
    // implementation uses in its APIs.
    auto shape = tensor_->shape();
    // We can't call refresh_numel() given we override sizes() too.
    numel_ = shape.Get().numel();
    sizes_and_strides_.set_sizes(shape.Get().sizes());
    // We can't call empty_tensor_restride(c10::MemoryFormat::Contiguous)
    // given we override sizes() too.
    std::vector<int64_t> updated_strides =
        ComputeArrayStrides(shape.Get().sizes());
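    // ComputeArrayStrides yields contiguous (row-major) strides; e.g. sizes
    // [2, 3, 4] produce strides [12, 4, 1].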
    for (const auto i : c10::irange(updated_strides.size())) {
      sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i];
    }
    generation_ = generation;
  }
}
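
// The *_custom accessors below are const in the TensorImpl interface, but
// they must lazily refresh the cached metadata; hence the const_cast.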
at::IntArrayRef LTCTensorImpl::sizes_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return sizes_default();
}

at::IntArrayRef LTCTensorImpl::strides_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return strides_default();
}

int64_t LTCTensorImpl::dim_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return dim_default();
}

int64_t LTCTensorImpl::numel_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return numel_default();
}
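
// Lazy tensors are not backed by eager storage at this level, so the storage
// offset is always zero.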
int64_t LTCTensorImpl::storage_offset_custom() const {
  return 0;
}

bool LTCTensorImpl::is_strides_like_custom(
    c10::MemoryFormat memory_format) const {
  TORCH_INTERNAL_ASSERT(memory_format != at::MemoryFormat::Contiguous);
  return false;
}

bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const {
  // This should return true, but we return false as a temporary workaround
  // for a PyTorch core issue; see https://github.com/pytorch/xla/pull/2682.
  return false;
}

bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const {
  // TODO(ezyang): I don't think this branch is actually necessary
  // TODO(ezyang): I don't think this logic is right, shouldn't we pass on
  // the memory format?
  if (tensor_->CurrentTensorData()) {
    return tensor_->CurrentTensorData()->is_contiguous();
  }
  // Only check that the storage is already contiguous.
  CHECK(is_contiguous_) << "Non-contiguous storage for lazy tensor";
  // TODO: I don't think this logic is right; we should check the requested
  // memory format before returning true.
  return true;
}

} // namespace lazy
} // namespace torch