1#include <c10/util/Exception.h>
2#include <utility>
3
4namespace at {
5
6/*
7[collapse dims] Updates sizes, and strides to reflect a "collapse" of
8the info, possibly excluding the optional excludeDim. A "collapsed" version
9of the info is the fewest dims that order the tensor's elements in the same
10way as the original info. If excludeDim is specified, the collapse is the
11fewest dims that order the tensor's elements as the original and preserve the
12excluded dimension, unless the tensor collapses to a point.
13
14This function returns a pair of values.
15
161) The (new) index of the preserved dimension if excludeDim is
17specified. 0 if the tensor is collapsed to a point. -1
18otherwise.
19
202) The new number of dimensions.
21*/
22template <typename T>
23inline std::pair<int64_t, int64_t> collapse_dims(
24 T* sizes,
25 T* strides,
26 int64_t dims,
27 const int excludeDim = -1) {
28 TORCH_CHECK(
29 excludeDim >= -1 && excludeDim < dims,
30 "expected excluded dim between -1 and dims - 1");
31
32 int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
33 int64_t newIndex = -1;
34 int64_t oldIndex = 0;
35 int64_t remappedExcludedDim = -1;
36
37 while (oldIndex < dims) {
38 // Finds a dimension to collapse into
39 for (; oldIndex < stopDim; ++oldIndex) {
40 if (sizes[oldIndex] == 1) {
41 continue;
42 }
43
44 ++newIndex;
45 sizes[newIndex] = sizes[oldIndex];
46 strides[newIndex] = strides[oldIndex];
47 ++oldIndex;
48 break;
49 }
50
51 // Collapses dims
52 for (; oldIndex < stopDim; ++oldIndex) {
53 if (sizes[oldIndex] == 1) {
54 continue;
55 }
56
57 if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
58 sizes[newIndex] *= sizes[oldIndex];
59 strides[newIndex] = strides[oldIndex];
60 } else {
61 ++newIndex;
62 sizes[newIndex] = sizes[oldIndex];
63 strides[newIndex] = strides[oldIndex];
64 }
65 }
66
67 // Handles excludeDim being set (oldIndex == excludeDim)
68 if (oldIndex != dims) {
69 // Preserves excluded dimension
70 ++newIndex;
71 sizes[newIndex] = sizes[oldIndex];
72 strides[newIndex] = strides[oldIndex];
73 remappedExcludedDim = newIndex;
74
75 // Restarts iteration after excludeDim
76 ++oldIndex;
77 stopDim = dims;
78 }
79 }
80
81 // Handles special case of all dims size 1
82 if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
83 dims = 1;
84 sizes[0] = 1;
85 strides[0] = 1;
86
87 return std::pair<int64_t, int64_t>(0, 1);
88 }
89
90 dims = newIndex + 1;
91 return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
92}
93
94} // namespace at
95