#include <ATen/MemoryOverlap.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Layout.h>
#include <c10/util/irange.h>

namespace at {

MemOverlap has_internal_overlap(const TensorBase& tensor) {
  return has_internal_overlap(tensor.unsafeGetTensorImpl());
}

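// Checks whether distinct indices of a strided tensor can refer to the same
// memory location: trivially no if the tensor is non-overlapping and dense,
// definitely yes if some dimension has stride 0 and size > 1, and otherwise
// too hard to decide cheaply.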
MemOverlap has_internal_overlap(TensorImpl* t) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t->layout() == kStrided);

  if (t->is_non_overlapping_and_dense()) {
    return MemOverlap::No;
  }

  auto strides = t->sym_strides();
  auto sizes = t->sym_sizes();
  for (const auto i : c10::irange(strides.size())) {
    if (strides[i] == 0 && sizes[i] > 1) {
      return MemOverlap::Yes;
    }
  }

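  // Not known to be non-overlapping and dense, yet no zero-stride dimension
  // proved overlap either, so report the conservative answer.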
  return MemOverlap::TooHard;
}

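// Raises an error if writing to the tensor would hit the same memory
// location through more than one index, e.g. a view produced by expand().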
void assert_no_internal_overlap(const TensorBase& t) {
  assert_no_internal_overlap(t.unsafeGetTensorImpl());
}

void assert_no_internal_overlap(TensorImpl* t) {
  TORCH_CHECK(has_internal_overlap(t) != MemOverlap::Yes,
    "unsupported operation: more than one element of the written-to tensor "
    "refers to a single memory location. Please clone() the tensor before "
    "performing the operation.");
}

MemOverlapStatus get_overlap_status(const TensorBase& a, const TensorBase& b) {
  return get_overlap_status(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
}

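// Classifies how the memory of two tensors relates: Full when they cover
// exactly the same bytes with the same strides, Partial when their byte
// ranges intersect without matching exactly, TooHard when either tensor is
// not non-overlapping and dense, and No otherwise.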
MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b) {
  if (a == b) return MemOverlapStatus::Full;
  if (a->numel() == 0 || b->numel() == 0) {
    return MemOverlapStatus::No;
  }
  if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) {
    return MemOverlapStatus::TooHard;
  }
  // Test for storage equality, rather than pointer equality.
  // This reduces precision, but if people are aliasing the
  // same pointer across multiple storages there are many
  // similar situations (e.g., storage().data() == storage().data()+1)
  // which we will miss.
  auto a_storage = a->unsafe_storage();
  if (a_storage && a_storage.is_alias_of(b->unsafe_storage())) {
    const auto a_begin = static_cast<char*>(a->data());
    const auto a_end = a_begin + a->numel() * a->itemsize();
    const auto b_begin = static_cast<char*>(b->data());
    const auto b_end = b_begin + b->numel() * b->itemsize();

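    // Identical byte ranges count as a full overlap only when the strides
    // also match; different strides map indices to those bytes differently,
    // which is treated as a partial overlap.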
    if (a_begin == b_begin && a_end == b_end) {
      return (a->strides() == b->strides()) ?
          MemOverlapStatus::Full : MemOverlapStatus::Partial;
    }
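    // The half-open byte ranges [a_begin, a_end) and [b_begin, b_end)
    // intersect iff each one starts before the other ends.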
    if (a_begin < b_end && b_begin < a_end) {
      return MemOverlapStatus::Partial;
    }
  }
  return MemOverlapStatus::No;
}

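// Raises an error if the two tensors partially share memory; a full overlap
// (e.g. an op writing in place to its own input) is still allowed here.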
void assert_no_partial_overlap(const TensorBase& a, const TensorBase& b) {
  assert_no_partial_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
}

void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b) {
  TORCH_CHECK(get_overlap_status(a, b) != MemOverlapStatus::Partial,
    "unsupported operation: some elements of the input tensor and "
    "the written-to tensor refer to a single memory location. "
    "Please clone() the tensor before performing the operation.");
}

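// Raises an error if the two tensors share any memory at all, whether the
// overlap is partial or full.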
void assert_no_overlap(const TensorBase& a, const TensorBase& b) {
  assert_no_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
}

void assert_no_overlap(TensorImpl* a, TensorImpl* b) {
  const auto lap = get_overlap_status(a, b);
  TORCH_CHECK(lap != MemOverlapStatus::Partial && lap != MemOverlapStatus::Full,
    "unsupported operation: some elements of the input tensor and "
    "the written-to tensor refer to a single memory location. "
    "Please clone() the tensor before performing the operation.");
}

} // namespace at