1 | #pragma once |
2 | |
3 | #include <ATen/Tensor.h> |
4 | #include <c10/core/TensorImpl.h> |
5 | #include <c10/util/Exception.h> |
6 | namespace at { |
7 | |
8 | // Struct implementing a sparse CSR tensor. It uses three 1-D tensors for |
9 | // denoting the data: `crow_indices_`, `col_indices_` and `values_`. |
10 | // The `crow_indices_` tensor is an integer tensor of shape `(size(0) + 1)` |
11 | // that represents the compressed row indices of the CSR tensor. The |
12 | // `col_indices_` tensor is an integer tensor of shape `(nnz())` |
13 | // that explicitly stores the column indices of each value of the sparse |
14 | // tensor. The `values_` tensor can be of any pytorch-supported data type |
15 | // and has shape `(nnz())`. |
16 | // |
17 | // Since the main advantage of the CSR format over the COO format is speed of |
18 | // computation, care must be taken to facilitate smooth interfacing of |
19 | // these data structures with optimized libraries such as MKL and MAGMA. |
20 | // Since the MKL interface for pytorch currently uses indexing with int32 |
21 | // type, it is important to make sure that the `crow_indices` and `col_indices` |
22 | // are of type int32 when calling MKL routines such as SPMM or SPMV. |
23 | // |
24 | // If not calling MKL, it should be alright to use 64 bit integer tensors |
25 | // for indexing. |
26 | struct TORCH_API SparseCsrTensorImpl : public TensorImpl { |
27 | Tensor crow_indices_; |
28 | Tensor col_indices_; |
29 | Tensor values_; |
30 | Layout layout_; |
31 | |
32 | public: |
33 | explicit SparseCsrTensorImpl( |
34 | at::DispatchKeySet, |
35 | at::Device device, |
36 | Layout layout, |
37 | const caffe2::TypeMeta); |
38 | |
39 | void resize_(int64_t nnz, IntArrayRef size); |
40 | void resize_and_clear_(int64_t sparse_dim, IntArrayRef size); |
41 | void resize_as_sparse_compressed_tensor_(const Tensor& src); |
42 | void set_member_tensors( |
43 | const Tensor& crow_indices, |
44 | const Tensor& col_indices, |
45 | const Tensor& values, |
46 | IntArrayRef size); |
47 | |
48 | const Tensor& compressed_indices() const { |
49 | return crow_indices_; |
50 | } |
51 | const Tensor& plain_indices() const { |
52 | return col_indices_; |
53 | } |
54 | const Tensor& values() const { |
55 | return values_; |
56 | } |
57 | int nnz() { |
58 | return col_indices_.size(-1); |
59 | } |
60 | |
61 | inline int64_t batch_dim() const noexcept { |
62 | return crow_indices_.dim() - 1; |
63 | } |
64 | |
65 | inline int64_t sparse_dim() const noexcept { |
66 | return 2; |
67 | } |
68 | |
69 | inline int64_t dense_dim() const noexcept { |
70 | return values_.dim() - batch_dim() - block_dim() - 1; |
71 | } |
72 | |
73 | private: |
74 | inline int64_t block_dim() const noexcept { |
75 | return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0); |
76 | } |
77 | |
78 | protected: |
79 | IntArrayRef strides_custom() const override; |
80 | SymIntArrayRef sym_strides_custom() const override; |
81 | bool is_contiguous_custom(MemoryFormat) const override; |
82 | |
83 | public: |
84 | void set_size(int64_t dim, int64_t new_size) override; |
85 | void set_stride(int64_t dim, int64_t new_stride) override; |
86 | void set_storage_offset(int64_t storage_offset) override; |
87 | Layout layout_impl() const override { |
88 | return layout_; |
89 | } |
90 | void set_layout(Layout layout) { |
91 | switch (layout) { |
92 | case kSparseCsr: |
93 | case kSparseCsc: |
94 | case kSparseBsr: |
95 | case kSparseBsc: |
96 | layout_ = layout; |
97 | break; |
98 | default: |
99 | TORCH_CHECK(false, "unsupported layout " , layout); |
100 | } |
101 | } |
102 | |
103 | /** |
104 | * Return a TensorImpl that is a shallow-copy of this TensorImpl. |
105 | * |
106 | * For usage of `version_counter` and `allow_tensor_metadata_change`, |
107 | * see NOTE [ TensorImpl Shallow-Copying ]. |
108 | */ |
109 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
110 | const c10::VariableVersion& version_counter, |
111 | bool allow_tensor_metadata_change) const override { |
112 | auto impl = c10::make_intrusive<SparseCsrTensorImpl>( |
113 | key_set(), device(), layout_impl(), dtype()); |
114 | copy_tensor_metadata( |
115 | /*src_impl=*/this, |
116 | /*dest_impl=*/impl.get(), |
117 | /*version_counter=*/version_counter, |
118 | /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); |
119 | impl->refresh_numel(); |
120 | return impl; |
121 | } |
122 | |
123 | /** |
124 | * Return a TensorImpl that is a shallow-copy of this TensorImpl. |
125 | * |
126 | * For usage of `version_counter` and `allow_tensor_metadata_change`, |
127 | * see NOTE [ TensorImpl Shallow-Copying ]. |
128 | */ |
129 | c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |
130 | c10::VariableVersion&& version_counter, |
131 | bool allow_tensor_metadata_change) const override { |
132 | auto impl = c10::make_intrusive<SparseCsrTensorImpl>( |
133 | key_set(), device(), layout_impl(), dtype()); |
134 | copy_tensor_metadata( |
135 | /*src_impl=*/this, |
136 | /*dest_impl=*/impl.get(), |
137 | /*version_counter=*/std::move(version_counter), |
138 | /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); |
139 | impl->refresh_numel(); |
140 | return impl; |
141 | } |
142 | |
143 | private: |
144 | explicit SparseCsrTensorImpl( |
145 | at::DispatchKeySet key_set, |
146 | const caffe2::TypeMeta data_type, |
147 | at::Tensor crow_indices, |
148 | at::Tensor col_indices, |
149 | at::Tensor values, |
150 | at::Layout layout); |
151 | |
152 | const char* tensorimpl_type_name() const override; |
153 | |
154 | /** |
155 | * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / |
156 | * storage_offset) from one TensorImpl to another TensorImpl. |
157 | * |
158 | * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE |
159 | * [ TensorImpl Shallow-Copying ]. |
160 | */ |
161 | static void copy_tensor_metadata( |
162 | const SparseCsrTensorImpl* src_sparse_impl, |
163 | SparseCsrTensorImpl* dest_sparse_impl, |
164 | const c10::VariableVersion& version_counter, |
165 | bool allow_tensor_metadata_change) { |
166 | TensorImpl::copy_tensor_metadata( |
167 | src_sparse_impl, |
168 | dest_sparse_impl, |
169 | version_counter, |
170 | allow_tensor_metadata_change); |
171 | |
172 | // Sparse-specific fields |
173 | dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices(); |
174 | dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices(); |
175 | dest_sparse_impl->values_ = src_sparse_impl->values(); |
176 | dest_sparse_impl->layout_ = src_sparse_impl->layout_impl(); |
177 | } |
178 | }; |
179 | } // namespace at |
180 | |