#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/native/Resize.h>

namespace at {

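// Default constructor: delegates to the main constructor below, allocating
// empty int32 crow_indices/col_indices and an empty values tensor of the
// requested dtype, all on the requested device.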
SparseCsrTensorImpl::SparseCsrTensorImpl(
    at::DispatchKeySet key_set,
    at::Device device,
    at::Layout layout,
    const caffe2::TypeMeta data_type)
    : SparseCsrTensorImpl(
          key_set,
          data_type,
          at::empty(
              {0},
              at::initialTensorOptions()
                  .device(device)
                  .dtype(ScalarType::Int)), // crow_indices
          at::empty(
              {0},
              at::initialTensorOptions()
                  .device(device)
                  .dtype(ScalarType::Int)), // col_indices
          at::empty(
              {0},
              at::initialTensorOptions()
                  .device(device)
                  .dtype(data_type)), // values
          layout) {}

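// Main constructor: takes ownership of the given index and value tensors and
// validates that the dispatch key set is consistent with the device the
// tensors live on.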
SparseCsrTensorImpl::SparseCsrTensorImpl(
    at::DispatchKeySet key_set,
    const caffe2::TypeMeta data_type,
    at::Tensor crow_indices,
    at::Tensor col_indices,
    at::Tensor values,
    at::Layout layout)
    : TensorImpl(key_set, data_type, values.device()),
      crow_indices_(std::move(crow_indices)),
      col_indices_(std::move(col_indices)),
      values_(std::move(values)),
      layout_(layout) {
  // https://pytorch.org/blog/pytorch-feature-classification-changes/#beta
  TORCH_WARN_ONCE("Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true),
                  " tensor support is in beta state. "
                  "If you miss a functionality in the sparse tensor support, please submit a feature request "
                  "to https://github.com/pytorch/pytorch/issues.");

  TORCH_INTERNAL_ASSERT(((key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kCPU)
                         || (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kCUDA)),
                        "Inconsistent key_set (=", key_set, ") and device (=", device(), ")");

  set_storage_access_should_throw();
  is_non_overlapping_and_dense_ = false;
  set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
  // TODO: If this check ever shows up as a bottleneck, which is unlikely given that
  // comparing devices only involves comparing the type and index (two integers), we
  // can move this to a DEBUG only assert. Until then this confirms and maintains a
  // crucial invariant.
  TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and ",
              at::sparse_csr::compressedIndicesName(layout_), " need to be on the same device.");
  TORCH_CHECK(values_.device() == col_indices_.device(), "Values and ",
              at::sparse_csr::plainIndicesName(layout_), " need to be on the same device.");
  TORCH_INTERNAL_ASSERT(values_.device() == device(),
                        "Values and compressed sparse tensor instance need to have the same device.");
}

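// Human-readable name of this TensorImpl subclass, reported in error
// messages and debugging output.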
const char* SparseCsrTensorImpl::tensorimpl_type_name() const {
  return "SparseCsrTensorImpl";
}

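// Resize to `size` keeping at most `nnz` specified elements. The entries of
// crow_indices must never exceed the number of specified elements, so growing
// the row count pads the new tail entries with nnz, while shrinking clamps
// the final entry to min(nnz, rows * cols).
//
// A sketch of the effect (illustrative values, not taken from a test):
//   before: size = (2, 3), nnz = 4, crow_indices = [0, 2, 4]
//   resize_(4, {4, 3}) grows the row count, so crow_indices becomes
//   [0, 2, 4, 4, 4] and col_indices/values keep length min(4, 4 * 3) = 4.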
void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "resize_ called on tensor with symbolic shape");
  auto rows = size[size.size() - 2];
  auto cols = size[size.size() - 1];
  auto old_crow_indices_size = crow_indices_.size(-1);

  auto new_crow_indices_size = DimVector(size.slice(0, size.size() - 2));
  new_crow_indices_size.push_back(rows + 1);
  crow_indices_.resize_(new_crow_indices_size);
  if (rows + 1 >= old_crow_indices_size) {
    // Growing (or keeping) the number of rows: the newly appended
    // compressed-index entries all point past the last specified element.
    crow_indices_.narrow(-1, old_crow_indices_size, rows + 1 - old_crow_indices_size).fill_(nnz);
  } else {
    // Shrinking the number of rows: clamp the last compressed index so it
    // cannot exceed the number of elements the smaller shape can hold.
    crow_indices_.narrow(-1, rows, 1).fill_(std::min<int64_t>(nnz, rows * cols));
  }
  auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
  col_indices_values_size.push_back(std::min<int64_t>(nnz, rows * cols));
  col_indices_.resize_(col_indices_values_size);
  values_.resize_(col_indices_values_size);
  sizes_and_strides_.set_sizes(size);
  refresh_numel();
}

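// Resize to `size` and drop all specified elements. Here `sparse_dim` counts
// the batch dimensions plus the two compressed dimensions: everything in
// `size` before the compressed pair is a batch dimension, everything after
// is a dense dimension. For blocked layouts (BSR/BSC) the number of
// compressed indices is further divided by the corresponding block size.
//
// An illustrative case (symbolic sizes B, M, N): for a batched CSR tensor,
// resize_and_clear_(3, {B, M, N}) yields crow_indices of shape (B, M + 1)
// filled with zeros and empty col_indices/values of shape (B, 0).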
void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, IntArrayRef size) {
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "resize_and_clear_ called on tensor with symbolic shape");
  TORCH_CHECK(sparse_dim >= 2, "resize_and_clear_ sparse dimensionality must be at least 2, got ", sparse_dim);
  TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
              sparse_dim, "), got ", size.size());
  auto batch_dim = sparse_dim - 2;
  auto batchsize = size.slice(0, batch_dim);
  auto densesize = size.slice(batch_dim + 2, size.size() - batch_dim - 2);

  auto values_size = DimVector(batchsize);
  values_size.push_back(0); // nse
  values_size.append(densesize.begin(), densesize.end());

  auto col_indices_size = DimVector(batchsize);
  col_indices_size.push_back(0); // nse

  auto n_compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout_, "resize_and_clear_",
      [&] () -> int64_t { return size[batch_dim]; },      // CSR/BSR: rows are compressed
      [&] () -> int64_t { return size[batch_dim + 1]; }   // CSC/BSC: columns are compressed
  );
  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout_,
      "resize_and_clear_",
      [] () {},
      [&] () {
        auto blocksize = this->values_.sizes().slice(this->batch_dim() + 1, 2);
        values_size.append(blocksize.begin(), blocksize.end());
        // NB: `the_layout` is bound by the dispatch macro to its layout
        // argument, i.e. layout_ here.
        n_compressed_indices /= blocksize[(the_layout == kSparseBsr ? 0 : 1)];
      });
  auto crow_indices_size = DimVector(batchsize);
  crow_indices_size.push_back(n_compressed_indices + 1);

  crow_indices_.resize_(crow_indices_size);
  crow_indices_.zero_();
  col_indices_.resize_(col_indices_size);
  values_.resize_(values_size);
  sizes_and_strides_.set_sizes(size);
  refresh_numel();
}

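// Resize self to match the shape of another sparse compressed tensor,
// reusing the existing index and value storage whenever the shapes already
// agree. When the logical shape changes, the index data is copied from `src`
// so that the result still satisfies the sparse compressed invariants.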
void SparseCsrTensorImpl::resize_as_sparse_compressed_tensor_(
    const Tensor& src) {
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "resize_as_sparse_compressed_tensor_ called on tensor with symbolic shape");

  // We cannot resize as another layout and preserve the invariants of the
  // self layout.
  TORCH_CHECK(
      src.layout() == layout_,
      "resize_as_sparse_compressed_tensor_: self and src must have the same layout, but got: self (",
      layout_,
      ") and source (",
      src.layout(),
      ")");

  Tensor compressed_indices;
  Tensor plain_indices;
  std::tie(compressed_indices, plain_indices) =
      sparse_csr::getCompressedPlainIndices(src);
  // Reuse self indices storage.
  if (crow_indices_.sizes() != compressed_indices.sizes()) {
    crow_indices_.resize_as_(compressed_indices);
  }
  if (col_indices_.sizes() != plain_indices.sizes()) {
    col_indices_.resize_as_(plain_indices);
  }
  // Update indices data to ensure the result is valid under the invariants
  // check.
  if ((sizes() != src.sizes()) || (dense_dim() != src.dense_dim())) {
    crow_indices_.copy_(compressed_indices);
    col_indices_.copy_(plain_indices);
  }
  // Reuse values storage.
  if (values_.sizes() != src.values().sizes()) {
    values_.resize_as_(src.values());
  }
  sizes_and_strides_.set_sizes(src.sizes());
  refresh_numel();
}

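// Install new index and value tensors directly. Callers are expected to
// provide tensors that already satisfy the sparse compressed invariants;
// only dtype and device consistency are validated here.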
void SparseCsrTensorImpl::set_member_tensors(
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    IntArrayRef size) {
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "set_member_tensors called on tensor with symbolic shape");

  // CSR Type Invariants
  TORCH_CHECK(
      values.scalar_type() == typeMetaToScalarType(dtype()),
      "dtype of values (",
      values.scalar_type(),
      ") must match dtype of sparse tensor (",
      typeMetaToScalarType(dtype()),
      ")");
  crow_indices_ = crow_indices;
  col_indices_ = col_indices;
  values_ = values;

  sizes_and_strides_.set_sizes(size);
  refresh_numel();
  // TODO: If this check ever shows up as a bottleneck, which is unlikely given that
  // comparing devices only involves comparing the type and index (two integers), we
  // can move this to a DEBUG only assert. Until then this confirms and maintains a
  // crucial invariant.
  TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and ",
              at::sparse_csr::compressedIndicesName(layout_), " need to be on the same device.");
  TORCH_CHECK(values_.device() == col_indices_.device(), "Values and ",
              at::sparse_csr::plainIndicesName(layout_), " need to be on the same device.");
  TORCH_CHECK(values_.device() == device(),
              "Values and compressed tensor instance need to be on the same device.");
}

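// Sparse compressed tensors have no strides or storage offset, so the
// strided-tensor accessors and mutators below are overridden to fail loudly
// instead of returning meaningless values.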
IntArrayRef SparseCsrTensorImpl::strides_custom() const {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have strides");
}
SymIntArrayRef SparseCsrTensorImpl::sym_strides_custom() const {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have strides");
}
void SparseCsrTensorImpl::set_size(int64_t dim, int64_t new_size) {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_size.");
}
void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_stride.");
}
void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset.");
}
bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous");
}

} // namespace at