1 | #pragma once |
---|---|
2 | |
3 | #include <c10/util/ArrayRef.h> |
4 | #include <c10/util/Exception.h> |
5 | #include <c10/util/irange.h> |
6 | |
7 | #include <vector> |
8 | |
9 | namespace torch { |
10 | namespace lazy { |
11 | |
12 | TORCH_API std::vector<int64_t> InversePermutation( |
13 | c10::ArrayRef<int64_t> input_permutation); |
14 | |
15 | TORCH_API bool IsPermutation(c10::ArrayRef<int64_t> permutation); |
16 | |
17 | // Gathers the input using the order specified by the permutation. For each i, |
18 | // output[i] = dimensions[permutation[i]]. The given permutation must be the |
19 | // same size as the input. |
20 | template <typename Container> |
21 | std::vector<typename Container::value_type> PermuteDimensions( |
22 | c10::ArrayRef<int64_t> permutation, |
23 | const Container& dimensions) { |
24 | using T = typename Container::value_type; |
25 | TORCH_CHECK( |
26 | dimensions.size() == permutation.size(), |
27 | "Invalid permutation specified. dimensions.size() != permutation.size() (", |
28 | dimensions.size(), |
29 | " vs. ", |
30 | permutation.size(), |
31 | ")"); |
32 | TORCH_CHECK( |
33 | IsPermutation(permutation), |
34 | "Invalid permutation specified. Permutation is not permutation"); |
35 | std::vector<T> output(dimensions.size()); |
36 | for (const auto i : c10::irange(permutation.size())) { |
37 | output[i] = dimensions[permutation[i]]; |
38 | } |
39 | return output; |
40 | } |
41 | |
42 | } // namespace lazy |
43 | } // namespace torch |
44 |