#pragma once

#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

namespace at {

// Struct implementing a sparse CSR tensor. It uses three 1-D tensors to
// denote the data: `crow_indices_`, `col_indices_` and `values_`.
// The `crow_indices_` tensor is an integer tensor of shape `(size(0) + 1)`
// that represents the compressed row indices of the CSR tensor. The
// `col_indices_` tensor is an integer tensor of shape `(nnz())`
// that explicitly stores the column indices of each value of the sparse
// tensor. The `values_` tensor can be of any PyTorch-supported data type
// and has shape `(nnz())`.
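//
// For example, the 3x3 matrix
//
//   [[1, 0, 2],
//    [0, 0, 3],
//    [4, 5, 6]]
//
// is stored as
//
//   crow_indices_ = [0, 2, 3, 6]
//   col_indices_  = [0, 2, 2, 0, 1, 2]
//   values_       = [1, 2, 3, 4, 5, 6]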
//
// Since the main advantage of the CSR format over the COO format is speed of
// computation, care must be taken to facilitate smooth interfacing of
// these data structures with optimized libraries such as MKL and MAGMA.
// Since the MKL interface for PyTorch currently uses indexing with the int32
// type, it is important to make sure that the `crow_indices` and
// `col_indices` tensors are of type int32 when calling MKL routines such as
// SPMM or SPMV.
//
// If not calling MKL, it is fine to use 64-bit integer tensors for indexing.
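//
// A minimal sketch of such a conversion before an MKL call (illustrative
// only; `csr` is a hypothetical sparse CSR tensor, and `Tensor::to` performs
// the dtype cast):
//
//   at::Tensor crow_int32 = csr.crow_indices().to(at::kInt);
//   at::Tensor col_int32 = csr.col_indices().to(at::kInt);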
struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
  Tensor crow_indices_;
  Tensor col_indices_;
  Tensor values_;
  Layout layout_;

 public:
  explicit SparseCsrTensorImpl(
      at::DispatchKeySet,
      at::Device device,
      Layout layout,
      const caffe2::TypeMeta);

  void resize_(int64_t nnz, IntArrayRef size);
  void resize_and_clear_(int64_t sparse_dim, IntArrayRef size);
  void resize_as_sparse_compressed_tensor_(const Tensor& src);
  void set_member_tensors(
      const Tensor& crow_indices,
      const Tensor& col_indices,
      const Tensor& values,
      IntArrayRef size);
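
  // A hedged usage sketch of `set_member_tensors` (illustrative only;
  // `impl`, `crow_indices`, `col_indices` and `values` are hypothetical
  // locals, and real construction typically goes through factory functions
  // such as at::_sparse_csr_tensor_unsafe rather than calling this directly):
  //
  //   impl->set_member_tensors(crow_indices, col_indices, values, {3, 3});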

  const Tensor& compressed_indices() const {
    return crow_indices_;
  }
  const Tensor& plain_indices() const {
    return col_indices_;
  }
  const Tensor& values() const {
    return values_;
  }
  int64_t nnz() const {
    return col_indices_.size(-1);
  }
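
  // With the 3x3 example above, `col_indices_` holds 6 column indices along
  // its last dimension, so nnz() == 6.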

  inline int64_t batch_dim() const noexcept {
    return crow_indices_.dim() - 1;
  }

  inline int64_t sparse_dim() const noexcept {
    return 2;
  }

  inline int64_t dense_dim() const noexcept {
    return values_.dim() - batch_dim() - block_dim() - 1;
  }
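
  // Worked example of the dimension arithmetic above (shapes assumed for
  // illustration): a batched BSR tensor whose crow_indices_ has shape
  // (B, nrows / p + 1) and whose values_ has shape (B, nnz, p, q, D) gives
  // batch_dim() == 1, sparse_dim() == 2, block_dim() == 2 and
  // dense_dim() == 5 - 1 - 2 - 1 == 1.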

 private:
  inline int64_t block_dim() const noexcept {
    return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0);
  }

 protected:
  IntArrayRef strides_custom() const override;
  SymIntArrayRef sym_strides_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;

 public:
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
  Layout layout_impl() const override {
    return layout_;
  }
  void set_layout(Layout layout) {
    switch (layout) {
      case kSparseCsr:
      case kSparseCsc:
      case kSparseBsr:
      case kSparseBsc:
        layout_ = layout;
        break;
      default:
        TORCH_CHECK(false, "unsupported layout ", layout);
    }
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
        key_set(), device(), layout_impl(), dtype());
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/version_counter,
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }

  /**
   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`,
   * see NOTE [ TensorImpl Shallow-Copying ].
   */
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override {
    auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
        key_set(), device(), layout_impl(), dtype());
    copy_tensor_metadata(
        /*src_impl=*/this,
        /*dest_impl=*/impl.get(),
        /*version_counter=*/std::move(version_counter),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    impl->refresh_numel();
    return impl;
  }
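
  // A hedged usage sketch (illustrative; `self_impl` is a hypothetical
  // SparseCsrTensorImpl*, and TensorImpl::version_counter() supplies the
  // counter forwarded here):
  //
  //   auto detached = self_impl->shallow_copy_and_detach(
  //       self_impl->version_counter(),
  //       /*allow_tensor_metadata_change=*/true);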

 private:
  explicit SparseCsrTensorImpl(
      at::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor crow_indices,
      at::Tensor col_indices,
      at::Tensor values,
      at::Layout layout);

  const char* tensorimpl_type_name() const override;

  /**
   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
   * storage_offset) from one TensorImpl to another TensorImpl.
   *
   * For usage of `version_counter` and `allow_tensor_metadata_change`, see
   * NOTE [ TensorImpl Shallow-Copying ].
   */
  static void copy_tensor_metadata(
      const SparseCsrTensorImpl* src_sparse_impl,
      SparseCsrTensorImpl* dest_sparse_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) {
    TensorImpl::copy_tensor_metadata(
        src_sparse_impl,
        dest_sparse_impl,
        version_counter,
        allow_tensor_metadata_change);

    // Sparse-specific fields
    dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices();
    dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices();
    dest_sparse_impl->values_ = src_sparse_impl->values();
    dest_sparse_impl->layout_ = src_sparse_impl->layout_impl();
  }
};
} // namespace at