1 | #include <ATen/ATen.h> |
2 | #include <ATen/SparseTensorImpl.h> |
3 | #include <ATen/InitialTensorOptions.h> |
4 | #include <ATen/core/LegacyTypeDispatch.h> |
5 | |
6 | namespace at { |
7 | |
8 | namespace { |
9 | DeviceType sparseTensorSetToDeviceType(DispatchKeySet key_set) { |
10 | auto k = c10::highestPriorityBackendTypeId(key_set); |
11 | TORCH_CHECK(c10::toFunctionalityKey(k) == DispatchKey::Sparse, |
12 | "cannot create sparse tensor with non sparse dispatch key " , k); |
13 | return c10::dispatchKeyToDeviceType(k); |
14 | } |
15 | } |
16 | |
17 | |
18 | // An empty dense tensor defaults to a 1-dimensional tensor of size [0] |
// (recall, it is not a 0-dimensional tensor, because such a tensor would
// be a scalar and have one element)
21 | // |
22 | // Thus, an empty sparse tensor should be a 1-dimensional tensor of size [0]. |
23 | // Furthermore, we have dim == sparse_dim + dense_dim; since this is a sparse |
24 | // tensor, let us say that an empty sparse tensor has sparse_dim == 1 and |
25 | // dense_dim == 0. (There is a degree of freedom here, but given that this |
26 | // is a sparse dimension, it seems reasonable to demand that sparse_dim > 0). |
27 | // |
28 | // This means that we allocate a [1,0] size indices tensor and a [0] size |
29 | // values tensor for such an empty tensor. |
// Default-constructs the canonical empty sparse tensor described in the
// comment above: delegates to the indices/values constructor with a
// [1, 0] int64 indices tensor and a [0] values tensor, both allocated on
// the device implied by key_set.
SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta data_type)
  : SparseTensorImpl(key_set, data_type
    // indices: [sparse_dim = 1, nnz = 0]; sparse indices are always int64
    , at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(ScalarType::Long))
    // values: [nnz = 0], carrying the tensor's own dtype
    , at::empty({0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(data_type))) {}
34 | |
// Constructs a sparse tensor impl around already-allocated indices/values
// tensors.  The delegating constructor above routes through here so that the
// TensorImpl device can be initialized from values.device(); as the asserts
// below enforce, only the canonical empty shapes ([1, 0] indices, [0] values)
// are accepted by this constructor.
SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta data_type, at::Tensor indices, at::Tensor values)
  : TensorImpl(key_set, data_type, values.device())
  , sparse_dim_(1)  // empty sparse tensor convention: sparse_dim == 1, dense_dim == 0 (see comment above)
  , dense_dim_(0)
  , indices_(std::move(indices))
  , values_(std::move(values)) {
  // we proxy to this constructor so we can initialize the device correctly, but really only indices/values of this shape are allowed.
  AT_ASSERT(indices_.sizes() == IntArrayRef({1, 0}));
  AT_ASSERT(values_.sizes() == IntArrayRef({0}));
  AT_ASSERT(values_.device() == indices_.device());
  AT_ASSERT(values_.device() == device());

  // Sparse tensors never satisfy the non-overlapping-and-dense property,
  // have no backing storage to access, and opt into the custom
  // sizes/strides policy (strides are not meaningful for sparse layout).
  is_non_overlapping_and_dense_ = false;
  set_storage_access_should_throw();
  set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
}
51 | |
// Destructor doesn't call release_resources because it's
// unnecessary; don't forget to change that if needed!
void SparseTensorImpl::release_resources() {
  // Release base-class resources first, then drop the references held to
  // the component indices/values tensors.
  TensorImpl::release_resources();
  values_.reset();
  indices_.reset();
}
59 | |
60 | void SparseTensorImpl::set_size(int64_t dim, int64_t new_size) { |
61 | AT_ERROR("sparse tensors do not have set_size" ); |
62 | } |
63 | void SparseTensorImpl::set_stride(int64_t dim, int64_t new_stride) { |
64 | AT_ERROR("sparse tensors do not have set_stride" ); |
65 | } |
66 | void SparseTensorImpl::set_storage_offset(int64_t storage_offset) { |
67 | AT_ERROR("sparse tensors do not have set_storage_offset" ); |
68 | } |
#ifdef DEBUG
// Debug-only override: a sparse tensor never has an associated storage.
// The body exists purely to assert the storage_-is-never-set invariant
// (this definition is only compiled under DEBUG; presumably the declaration
// in the header is guarded the same way).
bool SparseTensorImpl::has_storage() const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!storage_, "SparseTensorImpl assumes that storage_ is never set" );
  return false;
}
#endif
75 | |
76 | const char* SparseTensorImpl::tensorimpl_type_name() const { |
77 | return "SparseTensorImpl" ; |
78 | } |
79 | |
80 | void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values) { |
81 | TORCH_CHECK(allow_tensor_metadata_change(), "set_indices_and_values_unsafe " , err_msg_tensor_metadata_change_not_allowed); |
82 | |
83 | TORCH_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout " , indices.layout()); |
84 | TORCH_CHECK(!values.is_sparse(), "expected values to be a dense tensor, but got values of layout " , values.layout()); |
85 | |
86 | TORCH_CHECK(values.device().type() == device().type(), "device type of values (" , values.device().type(), ") must match device type of device().type()" , device().type(), ")" ); |
87 | TORCH_CHECK(values.scalar_type() == typeMetaToScalarType(dtype()), "dtype of values (" , values.scalar_type(), ") must match dtype of sparse tensor (" , typeMetaToScalarType(dtype()), ")" ); |
88 | TORCH_CHECK(indices.scalar_type() == kLong, "indices must be an int64 tensor" ); |
89 | TORCH_CHECK(indices.options().backend() == values.options().backend(), "backend of indices (" , indices.options().backend(), ") must match backend of values (" , values.options().backend(), ")" ); |
90 | TORCH_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device(), "device of indices (" , indices.get_device(), ") must match device of values (" , values.get_device(), ")" ); |
91 | |
92 | TORCH_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: " , indices.sym_sizes()); |
93 | TORCH_CHECK(indices.sym_size(1) == values.sym_size(0), "indices and values must have same nnz, but got nnz from indices: " , indices.sym_size(1), ", nnz from values: " , values.sym_size(0)); |
94 | TORCH_CHECK(indices.sym_size(0) == sparse_dim_, "indices has incorrect first dimension, expected " , sparse_dim_, ", got " , indices.sym_size(0)); |
95 | TORCH_CHECK(values.dim() == dense_dim_ + 1, "values has incorrect number of dimensions, expected " , dense_dim_ + 1, ", got " , values.dim()); |
96 | |
97 | auto dense_size_original = sym_sizes().slice(sparse_dim_); |
98 | std::vector<c10::SymInt> expected_values_size_vec = {values.sym_size(0)}; |
99 | expected_values_size_vec.insert(expected_values_size_vec.end(), dense_size_original.begin(), dense_size_original.end()); |
100 | SymIntArrayRef expected_values_size(expected_values_size_vec); |
101 | auto new_values_size = values.sym_sizes(); |
102 | TORCH_CHECK( |
103 | std::equal(expected_values_size.begin(), expected_values_size.end(), new_values_size.begin()), |
104 | "values has incorrect size, expected " , expected_values_size, ", got " , new_values_size |
105 | ); |
106 | |
107 | indices_ = indices; |
108 | values_ = values; |
109 | AT_ASSERT(device() == values_.device()); |
110 | AT_ASSERT(values_.device() == indices_.device()); |
111 | |
112 | coalesced_ = sym_nnz() < 2; |
113 | } |
114 | |
115 | |
116 | } // namespace at |
117 | |