#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Copy.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/quantized/Copy.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/vulkan/ops/Copy.h>
#include <ATen/native/TensorShape.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/vulkan/Context.h>
#include <ATen/metal/Context.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_copy_from.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/expand_copy.h>
#endif

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#include <fbgemm/FbgemmConvert.h>
#endif

namespace {

using namespace at;

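// Returns true when `self` is a contiguous destination and `src` is a
// "transposed" (column-major) 2-D tensor with matching sizes, dtype and
// conjugate/negation bits, and the copy is large enough (>= 60 * 60 elements)
// for the blocked transpose below to pay off. Smaller or mismatched copies
// fall through to the generic TensorIterator path instead.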
bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
  const int MIN_SZ = 60 * 60;
  return self.is_contiguous() && src.numel() != 0 && src.dim() == 2 &&
      src.stride(0) == 1 && src.stride(1) == src.size(0) &&
      self.scalar_type() == src.scalar_type() &&
      self.sizes().equals(src.sizes()) &&
      self.is_neg() == src.is_neg() &&
      self.is_conj() == src.is_conj() &&
      self.numel() >= MIN_SZ;
}

// Special-case copy where the destination tensor is contiguous and the source
// is a transposed matrix. This could be generalized to most copies, but it's
// trickier.
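// For illustration (an assumed example, not taken from this file), a copy
// that reaches this path might look like:
//   auto src = at::rand({1024, 512}).t();  // sizes {512, 1024}, strides {1, 512}
//   auto dst = at::empty({512, 1024});     // contiguous destination, same dtype
//   dst.copy_(src);                        // handled by copy_same_type_transpose_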
void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t BLOCK_SZ;
  if (self.scalar_type() == kByte) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    BLOCK_SZ = 120;
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    BLOCK_SZ = 60;
  }
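  // The scratch buffer below holds BLOCK_SZ x BLOCK_SZ elements; presumably
  // the tile sizes are chosen so it stays cache-resident (1-byte element types
  // can afford a larger tile).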
  Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options());

  // The code below is implemented with the assumption that sizes are equal
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes()));

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kHalf, kBool, kBFloat16, kComplexHalf, self.scalar_type(), "copy_", [&] {
    scalar_t* sp = src.data_ptr<scalar_t>();
    scalar_t* rp = self.data_ptr<scalar_t>();
    scalar_t* bp = buf.data_ptr<scalar_t>();

    int64_t NR = src.size(0);
    int64_t NC = src.size(1);
    for (int64_t R = 0; R < NR; R += BLOCK_SZ) {
      for (int64_t C = 0; C < NC; C += BLOCK_SZ) {
        scalar_t* spo = sp + R + C * NR;
        scalar_t* rpo = rp + C + R * NC;

        int nr = std::min(NR - R, BLOCK_SZ);
        int nc = std::min(NC - C, BLOCK_SZ);

        // 1. copy columns from src to buf
        for (const auto c : c10::irange(nc)) {
          memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(scalar_t));
        }

        // 2. transpose buf in place
        int rc_max = std::max(nr, nc);
        int rc_min = std::min(nr, nc);
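        // Each swap below exchanges the buf element at (r, c) read column-major
        // with the element at (r, c) read row-major. Running r up to
        // max(nr, nc) while capping c at min(nr, nc) also moves the entries of
        // a non-square edge block into the row-major positions that step 3
        // reads back.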
        for (const auto r : c10::irange(rc_max)) {
          int end = std::min(r, rc_min);
          for (const auto c : c10::irange(end)) {
            scalar_t tmp = bp[r + BLOCK_SZ * c];
            bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
            bp[r * BLOCK_SZ + c] = tmp;
          }
        }

        // 3. copy rows from buf to dst
        for (const auto r : c10::irange(nr)) {
          memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(scalar_t));
        }
      }
    }
  });
}

// Devices directly supported by this copy implementation. Other device types
// (e.g. XLA) may be supported by overriding copy_ and _copy_from.
bool is_supported_device(Device device) {
  DeviceType device_type = device.type();
  return device_type == kCPU || device_type == kCUDA || device_type == kHIP ||
      device_type == kVulkan || device_type == kMetal || device_type == kMPS;
}

} // namespace

namespace at {
namespace native {

static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) {
  // TODO: this should be handled during dispatch, but that's missing...
  TORCH_CHECK(self.defined(), "self is undefined");
  TORCH_CHECK(src.defined(), "src is undefined");

// FBGEMM kernel support exists only for the following case:
// 1. Memory format for the source and destination tensors is contiguous.
// 2. The device for both the source and destination tensors is CPU.
// 3. The dtype conversion is between FP32->FP16 or FP16->FP32.
// This checks that self.sizes() == src.sizes() because this code path doesn't
// support broadcasting. This also guards against out of bounds memory access
// when copying, see fbgemm::Float16ToFloat_ref.
// https://github.com/pytorch/pytorch/issues/88543
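// An illustrative (hypothetical) copy that satisfies the conditions above and
// takes the FBGEMM fast path:
//   auto src = at::rand({128, 128});              // kFloat, contiguous, CPU
//   auto dst = at::empty({128, 128}, at::kHalf);  // kHalf, contiguous, CPU
//   dst.copy_(src);                               // fbgemm::FloatToFloat16_simd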
#ifdef USE_FBGEMM
  if (((self.dtype() == at::kFloat && src.dtype() == at::kHalf) ||
       (self.dtype() == at::kHalf && src.dtype() == at::kFloat)) &&
      (self.device().is_cpu() && src.device().is_cpu()) &&
      ((self.is_contiguous() && src.is_contiguous()) ||
       (self.is_non_overlapping_and_dense() && self.strides() == src.strides())) &&
      (self.sizes() == src.sizes())) {
    if (src.dtype() == at::kFloat && self.dtype() == at::kHalf) {
      auto* output_ptr =
          reinterpret_cast<fbgemm::float16*>(self.data_ptr<at::Half>());
      if (self.numel() < at::internal::GRAIN_SIZE) {
        fbgemm::FloatToFloat16_simd(src.data_ptr<float>(), output_ptr, self.numel());
      } else {
        at::parallel_for(
            0,
            self.numel(),
            at::internal::GRAIN_SIZE,
            [&](int64_t begin, int64_t end) {
              fbgemm::FloatToFloat16_simd(
                  src.data_ptr<float>() + begin,
                  output_ptr + begin,
                  end - begin);
            });
      }
    } else {
      auto in_data = reinterpret_cast<fbgemm::float16*>(
          src.data_ptr<at::Half>());
      auto* output_ptr = self.data_ptr<float>();
      if (self.numel() < at::internal::GRAIN_SIZE) {
        fbgemm::Float16ToFloat_simd(in_data, output_ptr, self.numel());
      } else {
        at::parallel_for(
            0,
            self.numel(),
            at::internal::GRAIN_SIZE,
            [&](int64_t begin, int64_t end) {
              fbgemm::Float16ToFloat_simd(
                  in_data + begin, output_ptr + begin, end - begin);
            });
      }
    }
    return self;
  }
#endif

  if (self.is_same(src)) {
    return self;
  }

  // Copies into meta self are OK and just ignored (similar to inplace)
  if (self.is_meta()) {
    // TODO: need to see if there is extra error checking needed
    return self;
  }

  if (src.is_meta()) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot copy out of meta tensor; no data!");
  }

  // Re-dispatch copies when either src's or self's device type is not
  // implemented here (e.g. XLA).
  // _copy_from has a proper device dispatch setup.
  // This includes:
  //   cpu_tensor.copy_(xla_tensor) => xla_tensor._copy_from(cpu_tensor)
  //   xla_tensor.copy_(cpu_tensor) => cpu_tensor._copy_from(xla_tensor)
  // Both the _copy_from calls above will be dispatched to XLA's _copy_from kernels.
  if (!is_supported_device(src.device()) || !is_supported_device(self.device())) {
    at::_copy_from(src, self, non_blocking);
    return self;
  }

  if (self.is_quantized() && !src.is_quantized()) {
    return quantized_copy_from_float_(self, src);
  }

  if (self.is_quantized() && src.is_quantized()) {
    TORCH_CHECK(self.qscheme() == src.qscheme(),
                "Quantized Copy only works with same qscheme");
    TORCH_CHECK(self.scalar_type() == src.scalar_type());
    set_quantizer_(self, src.quantizer());
  }

  if (!self.is_quantized() && src.is_quantized()) {
    TORCH_CHECK(false, "Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor");
  }

  if (self.device().type() == at::kVulkan || src.device().type() == at::kVulkan) {
  #ifdef USE_VULKAN_API
    return vulkan::ops::copy_(self, src);
  #else
    return at::vulkan::vulkan_copy_(self, src);
  #endif
  }

  if (self.device().type() == at::kMetal || src.device().type() == at::kMetal) {
    return at::metal::metal_copy_(self, src);
  }

  // Exit early if self and src are views of the same data
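  // (For instance, t.copy_(t.detach()) or t.copy_(t.view_as(t)): the two
  // tensors alias the same storage with identical offset, sizes, strides and
  // dtype, so there is nothing to do.)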
  const bool is_same_data = (
      self.is_alias_of(src) &&
      self.storage_offset() == src.storage_offset() &&
      self.strides().equals(src.strides()) &&
      self.sizes().equals(src.sizes()) &&
      self.scalar_type() == src.scalar_type()
    );
  if (is_same_data) {
    return self;
  }

  auto iter = TensorIteratorConfig()
      .add_output(self)
      .add_input(src)
      .resize_outputs(false)
      .check_all_same_dtype(false)
      .check_all_same_device(false)
      .build();

  if (iter.numel() == 0) {
    return self;
  }

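  // If the source lives on an accelerator (CUDA/HIP/MPS), use that backend's
  // copy kernel even when the destination is on the CPU, so cross-device
  // copies are driven by the accelerator.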
  DeviceType device_type = iter.device_type(0);
  if (iter.device_type(1) == kCUDA) {
    device_type = kCUDA;
  } else if (iter.device_type(1) == kHIP) {
    device_type = kHIP;
  } else if (iter.device_type(1) == kMPS) {
    device_type = kMPS;
  }

  // TODO: if we need to, we can also enable this path for quantized tensors
  if (device_type == kCPU && copy_transpose_valid(self, src) && !self.is_quantized()) {
    copy_same_type_transpose_(self, src);
    return self;
  }

#ifdef USE_MPS
  if (self.device().type() == at::kMPS || src.device().type() == at::kMPS) {
    return at::native::mps::mps_copy_(self, src, non_blocking);
  }
#endif

  if (!self.is_complex() && src.is_complex()) {
    TORCH_WARN_ONCE("Casting complex values to real discards the imaginary part");
  }
  copy_stub(device_type, iter, non_blocking);
  return self;
}

Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
  // copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
  // (1) It isn't exposed to the frontend (no python bindings)
  // (2) It isn't exposed to the backend (it's a composite that decomposes into to() and expand_as() calls)
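  // In other words: materialize a fresh tensor with `self`'s layout and fill
  // it with `src`'s values, leaving `self` untouched.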
  auto r = clone_preserve_strides(self);
  r.copy_(src, non_blocking);
  return r;
}

Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
  auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
  {
    NoNamesGuard guard;
    if (self._is_zerotensor()) {
      TORCH_CHECK(false, "ZeroTensors are immutable. Please materialize the tensor using `.clone()`, if you want a mutable zero tensor.");
    }
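    // Copying from a ZeroTensor is equivalent to zeroing the destination, so
    // avoid materializing src.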
    if (src._is_zerotensor()) {
      return self.zero_();
    }
    copy_impl(self, src, non_blocking);
  }
  namedinference::propagate_names_if_nonempty(self, maybe_outnames);
  return self;
}

void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src) {
  // Called when we are copying into an overlapping index `dst`, but we don't
  // care which writer wins. Hacky but it works. This is only used by
  // CUDA_tensor_apply2 in the case that there are write overlaps.
  // FIXME: really, overlapping writes should be illegal/an error in Torch
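  // The only intended difference from the iterator built in copy_impl is that
  // memory-overlap checking is disabled (and dtype/device must already match).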
  auto iter = TensorIteratorConfig()
      .add_output(dst)
      .add_input(src)
      .resize_outputs(false)
      .set_check_mem_overlap(false)
      .check_all_same_dtype(true)
      .check_all_same_device(true)
      .build();
  copy_stub(iter.device_type(), iter, /*non_blocking=*/false);
}

DEFINE_DISPATCH(copy_stub);

} // namespace native
} // namespace at