1#pragma once
2
3#include <c10/util/ArrayRef.h>
4#include <c10/util/Exception.h>
5#include <c10/util/irange.h>
6
7#include <vector>
8
9namespace torch {
10namespace lazy {
11
12TORCH_API std::vector<int64_t> InversePermutation(
13 c10::ArrayRef<int64_t> input_permutation);
14
15TORCH_API bool IsPermutation(c10::ArrayRef<int64_t> permutation);
16
17// Gathers the input using the order specified by the permutation. For each i,
18// output[i] = dimensions[permutation[i]]. The given permutation must be the
19// same size as the input.
20template <typename Container>
21std::vector<typename Container::value_type> PermuteDimensions(
22 c10::ArrayRef<int64_t> permutation,
23 const Container& dimensions) {
24 using T = typename Container::value_type;
25 TORCH_CHECK(
26 dimensions.size() == permutation.size(),
27 "Invalid permutation specified. dimensions.size() != permutation.size() (",
28 dimensions.size(),
29 " vs. ",
30 permutation.size(),
31 ")");
32 TORCH_CHECK(
33 IsPermutation(permutation),
34 "Invalid permutation specified. Permutation is not permutation");
35 std::vector<T> output(dimensions.size());
36 for (const auto i : c10::irange(permutation.size())) {
37 output[i] = dimensions[permutation[i]];
38 }
39 return output;
40}
41
42} // namespace lazy
43} // namespace torch
44