#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Copy.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/quantized/Copy.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/vulkan/ops/Copy.h>
#include <ATen/native/TensorShape.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/vulkan/Context.h>
#include <ATen/metal/Context.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_copy_from.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/expand_copy.h>
#endif

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#include <fbgemm/FbgemmConvert.h>
#endif

namespace {

using namespace at;

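// True when the blocked-transpose fast path below applies: `self` is
// contiguous, `src` is a non-empty 2-D column-major (i.e. transposed) matrix
// with the same dtype, sizes, and conj/neg bits as `self`, and the copy is
// large enough (at least 60 * 60 elements) for the tiling to pay off.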
bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
  const int MIN_SZ = 60 * 60;
  return self.is_contiguous() && src.numel() != 0 && src.dim() == 2 &&
      src.stride(0) == 1 && src.stride(1) == src.size(0) &&
      self.scalar_type() == src.scalar_type() &&
      self.sizes().equals(src.sizes()) &&
      self.is_neg() == src.is_neg() &&
      self.is_conj() == src.is_conj() &&
      self.numel() >= MIN_SZ;
}

// special case copy where tensor is contiguous and src is a transposed matrix
// This can be generalized to most copies, but it's trickier
void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t BLOCK_SZ;
  if (self.scalar_type() == kByte) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    BLOCK_SZ = 120;
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    BLOCK_SZ = 60;
  }
  Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options());

  // The code below is implemented with the assumption that sizes are equal
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes()));

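  // Process the matrix in BLOCK_SZ x BLOCK_SZ tiles: stage each tile of `src`
  // in `buf`, transpose `buf` in place, then write its rows out contiguously
  // to `self`.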
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kHalf, kBool, kBFloat16, kComplexHalf, self.scalar_type(), "copy_", [&] {
    scalar_t* sp = src.data_ptr<scalar_t>();
    scalar_t* rp = self.data_ptr<scalar_t>();
    scalar_t* bp = buf.data_ptr<scalar_t>();

    int64_t NR = src.size(0);
    int64_t NC = src.size(1);
    for (int64_t R = 0; R < NR; R += BLOCK_SZ) {
      for (int64_t C = 0; C < NC; C += BLOCK_SZ) {
        scalar_t* spo = sp + R + C * NR;
        scalar_t* rpo = rp + C + R * NC;

        int nr = std::min(NR - R, BLOCK_SZ);
        int nc = std::min(NC - C, BLOCK_SZ);

        // 1. copy columns from src to buf
        for (const auto c : c10::irange(nc)) {
          memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(scalar_t));
        }

        // 2. transpose buf in place
        int rc_max = std::max(nr, nc);
        int rc_min = std::min(nr, nc);
        for (const auto r : c10::irange(rc_max)) {
          int end = std::min(r, rc_min);
          for (const auto c : c10::irange(end)) {
            scalar_t tmp = bp[r + BLOCK_SZ * c];
            bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
            bp[r * BLOCK_SZ + c] = tmp;
          }
        }

        // 3. copy rows from buf to dst
        for (const auto r : c10::irange(nr)) {
          memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(scalar_t));
        }
      }
    }
  });
}

// Devices directly supported by this copy implementation. Other device types
// (e.g. XLA) may be supported by overriding copy_ and _copy_from.
bool is_supported_device(Device device) {
  DeviceType device_type = device.type();
  return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan || device_type == kMetal || device_type == kMPS;
}

} // namespace

namespace at {
namespace native {

static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) {
  // TODO: this should be handled during dispatch, but that's missing...
  TORCH_CHECK(self.defined(), "self is undefined");
  TORCH_CHECK(src.defined(), "src is undefined");

  // The FBGEMM fast path is used only when all of the following hold:
  // 1. The source and destination tensors are contiguous.
  // 2. Both the source and destination tensors are on CPU.
  // 3. The copy is a dtype conversion between FP32 and FP16 (either direction).
  // The check that self.sizes() == src.sizes() is required because this code
  // path doesn't support broadcasting. It also guards against out-of-bounds
  // memory access when copying; see fbgemm::Float16ToFloat_ref.
  // https://github.com/pytorch/pytorch/issues/88543
  #ifdef USE_FBGEMM
  if (((self.dtype() == at::kFloat && src.dtype() == at::kHalf) ||
       (self.dtype() == at::kHalf && src.dtype() == at::kFloat)) &&
      (self.device().is_cpu() && src.device().is_cpu()) &&
      ((self.is_contiguous() && src.is_contiguous()) ||
       (self.is_non_overlapping_and_dense() && self.strides() == src.strides())) &&
      (self.sizes() == src.sizes())) {
    if (src.dtype() == at::kFloat && self.dtype() == at::kHalf) {
      auto* output_ptr =
          reinterpret_cast<fbgemm::float16*>(self.data_ptr<at::Half>());
      if (self.numel() < at::internal::GRAIN_SIZE) {
        fbgemm::FloatToFloat16_simd(src.data_ptr<float>(), output_ptr, self.numel());
      } else {
        at::parallel_for(
            0,
            self.numel(),
            at::internal::GRAIN_SIZE,
            [&](int64_t begin, int64_t end) {
              fbgemm::FloatToFloat16_simd(
                  src.data_ptr<float>() + begin,
                  output_ptr + begin,
                  end - begin);
            });
      }
    } else {
      auto in_data = reinterpret_cast<fbgemm::float16*>(
          src.data_ptr<at::Half>());
      auto* output_ptr = self.data_ptr<float>();
      if (self.numel() < at::internal::GRAIN_SIZE) {
        fbgemm::Float16ToFloat_simd(in_data, output_ptr, self.numel());
      } else {
        at::parallel_for(
            0,
            self.numel(),
            at::internal::GRAIN_SIZE,
            [&](int64_t begin, int64_t end) {
              fbgemm::Float16ToFloat_simd(
                  in_data + begin, output_ptr + begin, end - begin);
            });
      }
    }
    return self;
  }
  #endif

  if (self.is_same(src)) {
    return self;
  }

  // Copies into meta self are OK and just ignored (similar to inplace)
  if (self.is_meta()) {
    // TODO: need to see if there is extra error checking needed
    return self;
  }

  if (src.is_meta()) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot copy out of meta tensor; no data!");
  }

  // Re-dispatch copies when either src's or self's device type is not directly
  // supported here (e.g. XLA). _copy_from has a proper device dispatch setup.
  // This includes:
  //   cpu_tensor.copy_(xla_tensor) => xla_tensor._copy_from(cpu_tensor)
  //   xla_tensor.copy_(cpu_tensor) => cpu_tensor._copy_from(xla_tensor)
  // Both the _copy_from calls above will be dispatched to XLA's _copy_from kernels.
  if (!is_supported_device(src.device()) || !is_supported_device(self.device())) {
    at::_copy_from(src, self, non_blocking);
    return self;
  }

  if (self.is_quantized() && !src.is_quantized()) {
    return quantized_copy_from_float_(self, src);
  }

  if (self.is_quantized() && src.is_quantized()) {
    TORCH_CHECK(self.qscheme() == src.qscheme(),
                "Quantized Copy only works with same qscheme");
    TORCH_CHECK(self.scalar_type() == src.scalar_type());
    set_quantizer_(self, src.quantizer());
  }

  if (!self.is_quantized() && src.is_quantized()) {
    TORCH_CHECK(false, "Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor");
  }

  if (self.device().type() == at::kVulkan || src.device().type() == at::kVulkan) {
  #ifdef USE_VULKAN_API
    return vulkan::ops::copy_(self, src);
  #else
    return at::vulkan::vulkan_copy_(self, src);
  #endif
  }

  if (self.device().type() == at::kMetal || src.device().type() == at::kMetal) {
    return at::metal::metal_copy_(self, src);
  }

  // Exit early if self and src are views of the same data
  const bool is_same_data = (
      self.is_alias_of(src) &&
      self.storage_offset() == src.storage_offset() &&
      self.strides().equals(src.strides()) &&
      self.sizes().equals(src.sizes()) &&
      self.scalar_type() == src.scalar_type()
    );
  if (is_same_data) {
    return self;
  }

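  // Build a TensorIterator that broadcasts `src` to `self`'s shape and permits
  // mismatched dtypes and devices; the element-wise conversion itself is done
  // by copy_stub below.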
  auto iter = TensorIteratorConfig()
    .add_output(self)
    .add_input(src)
    .resize_outputs(false)
    .check_all_same_dtype(false)
    .check_all_same_device(false)
    .build();

  if (iter.numel() == 0) {
    return self;
  }

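  // Dispatch on the destination's device type by default, but if the source
  // lives on CUDA, HIP, or MPS, route the copy through that backend's copy
  // kernel instead.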
  DeviceType device_type = iter.device_type(0);
  if (iter.device_type(1) == kCUDA) {
    device_type = kCUDA;
  } else if (iter.device_type(1) == kHIP) {
    device_type = kHIP;
  } else if (iter.device_type(1) == kMPS) {
    device_type = kMPS;
  }

  // TODO: if we need to, we can also enable this path for quantized tensor
  if (device_type == kCPU && copy_transpose_valid(self, src) && !self.is_quantized()) {
    copy_same_type_transpose_(self, src);
    return self;
  }

#ifdef USE_MPS
  if (self.device().type() == at::kMPS || src.device().type() == at::kMPS) {
    return at::native::mps::mps_copy_(self, src, non_blocking);
  }
#endif

  if (!self.is_complex() && src.is_complex()) {
    TORCH_WARN_ONCE("Casting complex values to real discards the imaginary part");
  }
  copy_stub(device_type, iter, non_blocking);
  return self;
}

Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
  // copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
  // (1) It isn't exposed to the frontend (no python bindings)
  // (2) It isn't exposed to the backend (it's a composite that decomposes into to() and expand_as() calls).
  auto r = clone_preserve_strides(self);
  r.copy_(src, non_blocking);
  return r;
}

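// In-place copy with name inference: compute the broadcast names of `self`
// and `src`, perform the copy with names suppressed, then propagate the
// computed names back onto `self`.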
Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
  auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
  {
    NoNamesGuard guard;
    if (self._is_zerotensor()) {
      TORCH_CHECK(false, "ZeroTensors are immutable. Please materialize the tensor using `.clone()`, if you want a mutable zero tensor.");
    }
    if (src._is_zerotensor()) {
      return self.zero_();
    }
    copy_impl(self, src, non_blocking);
  }
  namedinference::propagate_names_if_nonempty(self, maybe_outnames);
  return self;
}

void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src) {
  // Called when we are copying into an overlapping index `dst`, but we don't
  // care which writer wins. Hacky but it works. This is only used by
  // CUDA_tensor_apply2 in case that there are write overlaps.
  // FIXME: really, overlapping writes should be illegal/an error in Torch
  auto iter = TensorIteratorConfig()
      .add_output(dst)
      .add_input(src)
      .resize_outputs(false)
      .set_check_mem_overlap(false)
      .check_all_same_dtype(true)
      .check_all_same_device(true)
      .build();
  copy_stub(iter.device_type(), iter, /*non_blocking=*/false);
}

DEFINE_DISPATCH(copy_stub);

} // namespace native
} // namespace at