1 | #include <c10/util/Exception.h> |
2 | #include <utility> |
3 | |
4 | namespace at { |
5 | |
6 | /* |
7 | [collapse dims] Updates sizes, and strides to reflect a "collapse" of |
8 | the info, possibly excluding the optional excludeDim. A "collapsed" version |
9 | of the info is the fewest dims that order the tensor's elements in the same |
10 | way as the original info. If excludeDim is specified, the collapse is the |
11 | fewest dims that order the tensor's elements as the original and preserve the |
12 | excluded dimension, unless the tensor collapses to a point. |
13 | |
14 | This function returns a pair of values. |
15 | |
16 | 1) The (new) index of the preserved dimension if excludeDim is |
17 | specified. 0 if the tensor is collapsed to a point. -1 |
18 | otherwise. |
19 | |
20 | 2) The new number of dimensions. |
21 | */ |
22 | template <typename T> |
23 | inline std::pair<int64_t, int64_t> collapse_dims( |
24 | T* sizes, |
25 | T* strides, |
26 | int64_t dims, |
27 | const int excludeDim = -1) { |
28 | TORCH_CHECK( |
29 | excludeDim >= -1 && excludeDim < dims, |
30 | "expected excluded dim between -1 and dims - 1" ); |
31 | |
32 | int64_t stopDim = (excludeDim == -1) ? dims : excludeDim; |
33 | int64_t newIndex = -1; |
34 | int64_t oldIndex = 0; |
35 | int64_t remappedExcludedDim = -1; |
36 | |
37 | while (oldIndex < dims) { |
38 | // Finds a dimension to collapse into |
39 | for (; oldIndex < stopDim; ++oldIndex) { |
40 | if (sizes[oldIndex] == 1) { |
41 | continue; |
42 | } |
43 | |
44 | ++newIndex; |
45 | sizes[newIndex] = sizes[oldIndex]; |
46 | strides[newIndex] = strides[oldIndex]; |
47 | ++oldIndex; |
48 | break; |
49 | } |
50 | |
51 | // Collapses dims |
52 | for (; oldIndex < stopDim; ++oldIndex) { |
53 | if (sizes[oldIndex] == 1) { |
54 | continue; |
55 | } |
56 | |
57 | if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) { |
58 | sizes[newIndex] *= sizes[oldIndex]; |
59 | strides[newIndex] = strides[oldIndex]; |
60 | } else { |
61 | ++newIndex; |
62 | sizes[newIndex] = sizes[oldIndex]; |
63 | strides[newIndex] = strides[oldIndex]; |
64 | } |
65 | } |
66 | |
67 | // Handles excludeDim being set (oldIndex == excludeDim) |
68 | if (oldIndex != dims) { |
69 | // Preserves excluded dimension |
70 | ++newIndex; |
71 | sizes[newIndex] = sizes[oldIndex]; |
72 | strides[newIndex] = strides[oldIndex]; |
73 | remappedExcludedDim = newIndex; |
74 | |
75 | // Restarts iteration after excludeDim |
76 | ++oldIndex; |
77 | stopDim = dims; |
78 | } |
79 | } |
80 | |
81 | // Handles special case of all dims size 1 |
82 | if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) { |
83 | dims = 1; |
84 | sizes[0] = 1; |
85 | strides[0] = 1; |
86 | |
87 | return std::pair<int64_t, int64_t>(0, 1); |
88 | } |
89 | |
90 | dims = newIndex + 1; |
91 | return std::pair<int64_t, int64_t>(remappedExcludedDim, dims); |
92 | } |
93 | |
94 | } // namespace at |
95 | |