1 | #pragma once |
---|---|
#include <algorithm>
#include <type_traits>

#include <c10/util/ArrayRef.h>
#include <c10/util/DimVector.h>
4 | |
5 | namespace c10 { |
6 | |
7 | // Computes the contiguous strides of a tensor, given its sizes. |
8 | static inline DimVector contiguous_strides(const IntArrayRef sizes) { |
9 | using Int = IntArrayRef::value_type; |
10 | const Int dims = static_cast<Int>(sizes.size()); |
11 | |
12 | // With this intialisation we get the case dim == 0 or 1 right |
13 | DimVector strides(dims, 1); |
14 | |
15 | for (auto i = dims - 2; i >= 0; --i) { |
16 | // Strides can't be 0 even if sizes are 0. |
17 | strides[i] = strides[i + 1] * std::max(sizes[i + 1], Int{1}); |
18 | } |
19 | |
20 | return strides; |
21 | } |
22 | |
23 | } // namespace c10 |
24 |