1 | #include <ATen/MemoryOverlap.h> |
2 | #include <ATen/core/TensorBase.h> |
3 | #include <c10/core/Layout.h> |
4 | #include <c10/util/irange.h> |
5 | |
6 | namespace at { |
7 | |
8 | MemOverlap has_internal_overlap(const TensorBase& tensor) { |
9 | return has_internal_overlap(tensor.unsafeGetTensorImpl()); |
10 | } |
11 | |
12 | MemOverlap has_internal_overlap(TensorImpl* t) { |
13 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t->layout() == kStrided); |
14 | |
15 | if (t->is_non_overlapping_and_dense()) { |
16 | return MemOverlap::No; |
17 | } |
18 | |
19 | auto strides = t->sym_strides(); |
20 | auto sizes = t->sym_sizes(); |
21 | for (const auto i : c10::irange(strides.size())) { |
22 | if (strides[i] == 0 && sizes[i] > 1) { |
23 | return MemOverlap::Yes; |
24 | } |
25 | } |
26 | |
27 | return MemOverlap::TooHard; |
28 | } |
29 | |
30 | void assert_no_internal_overlap(const TensorBase& t) { |
31 | assert_no_internal_overlap(t.unsafeGetTensorImpl()); |
32 | } |
33 | |
34 | void assert_no_internal_overlap(TensorImpl* t) { |
35 | TORCH_CHECK(has_internal_overlap(t) != MemOverlap::Yes, |
36 | "unsupported operation: more than one element of the written-to tensor " |
37 | "refers to a single memory location. Please clone() the tensor before " |
38 | "performing the operation." ); |
39 | } |
40 | |
41 | MemOverlapStatus get_overlap_status(const TensorBase& a, const TensorBase& b) { |
42 | return get_overlap_status(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl()); |
43 | } |
44 | |
45 | MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b) { |
46 | if (a == b) return MemOverlapStatus::Full; |
47 | if (a->numel() == 0 || b->numel() == 0) { |
48 | return MemOverlapStatus::No; |
49 | } |
50 | if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) { |
51 | return MemOverlapStatus::TooHard; |
52 | } |
53 | // Test for storage equality, rather than pointer equality. |
54 | // This reduces precision, but if people are aliasing the |
55 | // same pointer across multiple storages there are many |
56 | // similar situations (e.g., storage().data() == storage().data()+1) |
57 | // which we will miss. |
58 | auto a_storage = a->unsafe_storage(); |
59 | if (a_storage && a_storage.is_alias_of(b->unsafe_storage())) { |
60 | const auto a_begin = static_cast<char*>(a->data()); |
61 | const auto a_end = a_begin + a->numel() * a->itemsize(); |
62 | const auto b_begin = static_cast<char*>(b->data()); |
63 | const auto b_end = b_begin + b->numel() * b->itemsize(); |
64 | |
65 | if (a_begin == b_begin && a_end == b_end) { |
66 | return (a->strides() == b->strides()) ? |
67 | MemOverlapStatus::Full : MemOverlapStatus::Partial; |
68 | } |
69 | if (a_begin < b_end && b_begin < a_end) { |
70 | return MemOverlapStatus::Partial; |
71 | } |
72 | } |
73 | return MemOverlapStatus::No; |
74 | } |
75 | |
76 | void assert_no_partial_overlap(const TensorBase& a, const TensorBase& b) { |
77 | assert_no_partial_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl()); |
78 | } |
79 | |
80 | void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b) { |
81 | TORCH_CHECK(get_overlap_status(a, b) != MemOverlapStatus::Partial, |
82 | "unsupported operation: some elements of the input tensor and " |
83 | "the written-to tensor refer to a single memory location. " |
84 | "Please clone() the tensor before performing the operation." ); |
85 | } |
86 | |
87 | void assert_no_overlap(const TensorBase& a, const TensorBase& b) { |
88 | assert_no_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl()); |
89 | } |
90 | |
91 | void assert_no_overlap(TensorImpl* a, TensorImpl* b) { |
92 | const auto lap = get_overlap_status(a, b); |
93 | TORCH_CHECK(lap != MemOverlapStatus::Partial && lap != MemOverlapStatus::Full, |
94 | "unsupported operation: some elements of the input tensor and " |
95 | "the written-to tensor refer to a single memory location. " |
96 | "Please clone() the tensor before performing the operation." ); |
97 | } |
98 | |
99 | } |
100 | |