1#pragma once
2
3#include <ATen/Parallel.h>
4#include <ATen/SparseTensorImpl.h>
5#include <ATen/core/Tensor.h>
6
7#ifndef AT_PER_OPERATOR_HEADERS
8#include <ATen/Functions.h>
9#else
10#include <ATen/ops/empty.h>
11#endif
12
13namespace at {
14namespace sparse {
15
// Aliases used purely for documentary purposes: a SparseTensor is just a
// Tensor whose impl is a SparseTensorImpl (see get_sparse_impl below).
using SparseTensor = Tensor;
// NOTE(review): `Type` is a legacy ATen concept; presumably this alias is
// kept only for backward compatibility — confirm before removing.
using SparseType = Type;
19
20// This is an internal utility function for getting at the SparseTensorImpl,
21// so that we can write sparse tensor specific accessors for special fields
22// in SparseTensor. You should only use this for writing low level
23// setters/getters for SparseTensorImpl fields; otherwise, you should use
24// the low level setters/getters that were implemented using this.
25//
26// This may be called repeatedly, so make sure it's pretty cheap.
27inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
28 TORCH_INTERNAL_ASSERT(
29 self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
30 return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
31}
32
33// Takes indices and values and directly puts them into the sparse tensor, no
34// copy. This used to be called THSTensor_(_move)
35inline void alias_into_sparse(
36 const SparseTensor& self,
37 const Tensor& indices,
38 const Tensor& values) {
39 get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
40}
41
42// Take indices and values and makes a (data) copy of them to put into the
43// sparse indices/values. This used to be called THSTensor_(_set)
44inline void copy_into_sparse(
45 const SparseTensor& self,
46 const Tensor& indices,
47 const Tensor& values,
48 bool non_blocking) {
49 alias_into_sparse(
50 self,
51 indices.to(self._indices().options(), non_blocking, /*copy=*/true),
52 values.to(self._values().options(), non_blocking, /*copy=*/true));
53}
54
55// TODO: put this into the public API
56inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
57 return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
58}
59
60inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
61 return self.sparse_dim() == src.sparse_dim() &&
62 self.dense_dim() == src.dense_dim();
63}
64
65// Give us a new values tensor, with the same dimensionality
66// as 'values' but with a new number of non-zero elements.
67// TODO: Expose this for real in ATen, some day?
68// NB: Doesn't preserve data.
69inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
70 std::vector<int64_t> size = values.sizes().vec();
71 size[0] = nnz;
72 return at::empty(size, values.options());
73}
74
75// NOTE [ Flatten Sparse Indices ]
76// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
77// indices tensor. E.g.,
78// input = [[2, 4, 0],
79// [3, 1, 10]]
80// full_size = [2, 12]
81// output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
82//
83// In other words, assuming that each `indices[i, :]` is a valid index to a
84// tensor `t` of shape `full_size`. This returns the corresponding indices to
85// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
// If force_clone is true, the result will be forced to be a clone of the
// input indices (rather than possibly aliasing them).
88TORCH_API Tensor flatten_indices(
89 const Tensor& indices,
90 IntArrayRef full_size,
91 bool force_clone = false);
92
93// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten
94// Sparse Indices ], except this one allows partial flatten: only flatten on
// specified dims. Note that the flattened indices might be uncoalesced if
// dims_to_flatten.size() < sparse_dim. Also, if the input indices are already
// coalesced, the flattened indices will also be sorted.
98//
99// args:
100// indices: sparse tensor indices
101// sizes: sparse tensor sizes
102// dims_to_flatten: a list of dim index to flatten
103//
104// Ex1:
105// indices = [[2, 4, 0],
106// [3, 1, 3]]
107// sizes = [2, 12]
108// dims_to_flatten = [0, 1]
109// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
110//
111// Ex2:
112// dims_to_flatten = [1]
113// new_indices = [ 3, 1, 3 ] # uncoalesced
114TORCH_API Tensor flatten_indices_by_dims(
115 const Tensor& indices,
116 const IntArrayRef& sizes,
117 const IntArrayRef& dims_to_flatten);
118
// Find the CSR representation for a row `indices` from the COO format
// `indices`: pointer to `nnz` row indices (one per non-zero element);
//   presumably expected to be sorted/coalesced — confirm against the
//   definition before relying on unsorted input.
// `dim`: number of rows of the sparse matrix (length of the CSR pointer
//   array is derived from it).
// `nnz`: number of non-zero elements described by `indices`.
TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);
121
122} // namespace sparse
123} // namespace at
124