1#include <ATen/ATen.h>
2#include <ATen/InitialTensorOptions.h>
3#include <ATen/SparseCsrTensorImpl.h>
4#include <ATen/SparseCsrTensorUtils.h>
5#include <ATen/SparseTensorImpl.h>
6#include <ATen/SparseTensorUtils.h>
7#include <ATen/core/LegacyTypeDispatch.h>
8#include <ATen/native/Resize.h>
9
10namespace at {
11
12SparseCsrTensorImpl::SparseCsrTensorImpl(
13 at::DispatchKeySet key_set,
14 at::Device device,
15 at::Layout layout,
16 const caffe2::TypeMeta data_type)
17 : SparseCsrTensorImpl(
18 key_set,
19 data_type,
20 at::empty(
21 {0},
22 at::initialTensorOptions()
23 .device(device)
24 .dtype(ScalarType::Int)) // crow_indices
25 ,
26 at::empty(
27 {0},
28 at::initialTensorOptions()
29 .device(device)
30 .dtype(ScalarType::Int)) // col_indices
31 ,
32 at::empty(
33 {0},
34 at::initialTensorOptions()
35 .device(device)
36 .dtype(data_type)) // values
37 ,
38 layout
39 ) {}
40
41SparseCsrTensorImpl::SparseCsrTensorImpl(
42 at::DispatchKeySet key_set,
43 const caffe2::TypeMeta data_type,
44 at::Tensor crow_indices,
45 at::Tensor col_indices,
46 at::Tensor values,
47 at::Layout layout)
48 : TensorImpl(key_set, data_type, values.device()),
49 crow_indices_(std::move(crow_indices)),
50 col_indices_(std::move(col_indices)),
51 values_(std::move(values)),
52 layout_(layout) {
53 // https://pytorch.org/blog/pytorch-feature-classification-changes/#beta
54 TORCH_WARN_ONCE("Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensor support is in beta state. "
55 "If you miss a functionality in the sparse tensor support, please submit a feature request "
56 "to https://github.com/pytorch/pytorch/issues.");
57
58 TORCH_INTERNAL_ASSERT(((key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kCPU)
59 || (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kCUDA)),
60 "Inconsistent key_set (=", key_set, ") and device (=", device(), ")");
61
62 set_storage_access_should_throw();
63 is_non_overlapping_and_dense_ = false;
64 set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
65 // TODO: If this check ever shows up as a bottleneck, which is unlikely given that
66 // comparing devices only involves comparing the type and index (two integers), we
67 // can move this to a DEBUG only assert. Until then this confirms and maintains a
68 // crucial invariance.
69 TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and ",
70 at::sparse_csr::compressedIndicesName(layout_), " need to be on the same device.");
71 TORCH_CHECK(values_.device() == col_indices_.device(), "Values and ",
72 at::sparse_csr::plainIndicesName(layout_), " need to be on the same device.");
73 TORCH_INTERNAL_ASSERT(values_.device() == device(),
74 "Values and compressed sparse tensor instance need to have the same device.");
75}
76
77const char* SparseCsrTensorImpl::tensorimpl_type_name() const {
78 return "SparseCsrTensorImpl";
79}
80
81void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
82 TORCH_CHECK(
83 !has_symbolic_sizes_strides_,
84 "resize_ called on tensor with symbolic shape")
85 auto rows = size[size.size() - 2];
86 auto cols = size[size.size() - 1];
87 auto old_crow_indices_size = crow_indices_.size(-1);
88
89 auto new_crow_indices_size = DimVector(size.slice(0, size.size() - 2));
90 new_crow_indices_size.push_back(rows + 1);
91 crow_indices_.resize_(new_crow_indices_size);
92 if (rows + 1 >= old_crow_indices_size) {
93 crow_indices_.narrow(-1, old_crow_indices_size, rows + 1 - old_crow_indices_size).fill_(nnz);
94 } else {
95 crow_indices_.narrow(-1, rows, 1).fill_(std::min<int64_t>(nnz, rows*cols));
96 }
97 auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
98 col_indices_values_size.push_back(std::min<int64_t>(nnz, rows*cols));
99 col_indices_.resize_(col_indices_values_size);
100 values_.resize_(col_indices_values_size);
101 sizes_and_strides_.set_sizes(size);
102 refresh_numel();
103}
104
105void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, IntArrayRef size) {
106 TORCH_CHECK(
107 !has_symbolic_sizes_strides_,
108 "resize_and_clear_ called on tensor with symbolic shape");
109 TORCH_CHECK(sparse_dim >= 2, "resize_and_clear_ sparse dimensionality must be at least 2, got ", sparse_dim);
110 TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
111 sparse_dim, "), got ", size.size());
112 auto batch_dim = sparse_dim - 2;
113 auto batchsize = size.slice(0, batch_dim);
114 auto densesize = size.slice(batch_dim + 2, size.size() - batch_dim - 2);
115
116 auto values_size = DimVector(batchsize);
117 values_size.push_back(0); // nse
118 values_size.append(densesize.begin(), densesize.end());
119
120 auto col_indices_size = DimVector(batchsize);
121 col_indices_size.push_back(0); // nse
122
123 auto n_compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout_, "resize_and_clear_",
124 [&] () -> int64_t { return size[batch_dim]; },
125 [&] () -> int64_t { return size[batch_dim + 1]; }
126 );
127 AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout_,
128 "resize_and_clear_",
129 [] () {},
130 [&] () {
131 auto blocksize = this->values_.sizes().slice(this->batch_dim() + 1, 2);
132 values_size.append(blocksize.begin(), blocksize.end());
133 n_compressed_indices /= blocksize[(the_layout == kSparseBsr ? 0 : 1)];
134 });
135 auto crow_indices_size = DimVector(batchsize);
136 crow_indices_size.push_back(n_compressed_indices + 1);
137
138 crow_indices_.resize_(crow_indices_size);
139 crow_indices_.zero_();
140 col_indices_.resize_(col_indices_size);
141 values_.resize_(values_size);
142 sizes_and_strides_.set_sizes(size);
143 refresh_numel();
144}
145
146void SparseCsrTensorImpl::resize_as_sparse_compressed_tensor_(
147 const Tensor& src) {
148 TORCH_CHECK(
149 !has_symbolic_sizes_strides_,
150 "resize_as_sparse_compressed_tensor_ called on tensor with symbolic shape");
151
152 // We cannot resize as other layout and preserve the invariants for self
153 // layout
154 TORCH_CHECK(
155 src.layout() == layout_,
156 "resize_as_sparse_compressed_tensor_: self and src must have the same layout, but got: self (",
157 layout_,
158 ") and source (",
159 src.layout(),
160 ")");
161
162 Tensor compressed_indices;
163 Tensor plain_indices;
164 std::tie(compressed_indices, plain_indices) =
165 sparse_csr::getCompressedPlainIndices(src);
166 // reuse self indices storage
167 if (crow_indices_.sizes() != compressed_indices.sizes()) {
168 crow_indices_.resize_as_(compressed_indices);
169 }
170 if (col_indices_.sizes() != plain_indices.sizes()) {
171 col_indices_.resize_as_(plain_indices);
172 }
173 // Update indices data to ensure result is valid under invariants check
174 if ((sizes() != src.sizes()) || (dense_dim() != src.dense_dim())) {
175 crow_indices_.copy_(compressed_indices);
176 col_indices_.copy_(plain_indices);
177 }
178 // Reuse values storage
179 if (values_.sizes() != src.values().sizes()) {
180 values_.resize_as_(src.values());
181 }
182 sizes_and_strides_.set_sizes(src.sizes());
183 refresh_numel();
184}
185
186void SparseCsrTensorImpl::set_member_tensors(
187 const Tensor& crow_indices,
188 const Tensor& col_indices,
189 const Tensor& values,
190 IntArrayRef size) {
191 TORCH_CHECK(
192 !has_symbolic_sizes_strides_,
193 "set_member_tensors called on tensor with symbolic shape");
194
195 // CSR Type Invariants
196 TORCH_CHECK(
197 values.scalar_type() == typeMetaToScalarType(dtype()),
198 "dtype of values (",
199 values.scalar_type(),
200 ") must match dtype of sparse tensor (",
201 typeMetaToScalarType(dtype()),
202 ")");
203 crow_indices_ = crow_indices;
204 col_indices_ = col_indices;
205 values_ = values;
206
207 sizes_and_strides_.set_sizes(size);
208 refresh_numel();
209 // TODO: If this check ever shows up as a bottleneck, which is unlikely given that
210 // comparing devices only involves comparing the type and index (two integers), we
211 // can move this to a DEBUG only assert. Until then this confirms and maintains a
212 // crucial invariance.
213 TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and ",
214 at::sparse_csr::compressedIndicesName(layout_), " need to be on the same device.");
215 TORCH_CHECK(values_.device() == col_indices_.device(), "Values and ",
216 at::sparse_csr::plainIndicesName(layout_), " need to be on the same device.");
217 TORCH_CHECK(values_.device() == device(),
218 "Values and compressed tensor instance need to be on the same device.");
219}
220
221IntArrayRef SparseCsrTensorImpl::strides_custom() const {
222 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have strides");
223}
224SymIntArrayRef SparseCsrTensorImpl::sym_strides_custom() const {
225 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have strides");
226}
227void SparseCsrTensorImpl::set_size(int64_t dim, int64_t new_size) {
228 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_size.");
229}
230void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
231 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_stride.");
232}
233void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) {
234 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset.");
235}
236bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const {
237 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous");
238}
239
240} // namespace at
241