1 | #include <c10/util/irange.h> |
---|---|
2 | #include <torch/csrc/lazy/core/permutation_util.h> |
3 | |
4 | #include <algorithm> |
5 | #include <numeric> |
6 | |
7 | namespace torch { |
8 | namespace lazy { |
9 | |
10 | std::vector<int64_t> InversePermutation( |
11 | c10::ArrayRef<int64_t> input_permutation) { |
12 | TORCH_CHECK(IsPermutation(input_permutation)); |
13 | std::vector<int64_t> output_permutation(input_permutation.size(), -1); |
14 | for (const auto i : c10::irange(input_permutation.size())) { |
15 | output_permutation.at(input_permutation.at(i)) = i; |
16 | } |
17 | return output_permutation; |
18 | } |
19 | |
20 | bool IsPermutation(c10::ArrayRef<int64_t> permutation) { |
21 | std::vector<int64_t> trivial_permutation(permutation.size()); |
22 | std::iota(trivial_permutation.begin(), trivial_permutation.end(), 0); |
23 | return std::is_permutation( |
24 | permutation.begin(), permutation.end(), trivial_permutation.begin()); |
25 | } |
26 | |
27 | } // namespace lazy |
28 | } // namespace torch |
29 |