1 | #pragma once |
2 | |
3 | #include <c10/core/Scalar.h> |
4 | #include <c10/util/BFloat16.h> |
5 | #include <c10/util/Half.h> |
6 | #include <c10/util/Optional.h> |
7 | #include <torch/csrc/lazy/core/permutation_util.h> |
8 | #include <torch/csrc/lazy/core/shape.h> |
9 | #include <torch/csrc/lazy/core/util.h> |
10 | |
11 | #include <complex> |
12 | #include <functional> |
13 | #include <tuple> |
14 | #include <vector> |
15 | |
16 | // TODO: Consolidate this file with util.h |
17 | |
18 | namespace torch { |
19 | namespace lazy { |
20 | |
21 | // Converts an iterable container to a std::vector<int64_t>. |
// `input` is passed unchanged to ToVector<int64_t>, so any requirement on the
// container type S is whatever ToVector imposes. ToVector is not defined in
// this header (presumably it comes from lazy/core/util.h -- TODO confirm).
22 | template <typename S> |
23 | static std::vector<int64_t> ToI64Vector(const S& input) { |
24 | return ToVector<int64_t>(input); |
25 | } |
26 | |
27 | // Creates a list of dimension sizes from `sizes` by dropping the `drop_dims` ones. |
// NOTE(review): `drop_dims` presumably holds dimension indices into `sizes`
// (not size values) -- confirm against the implementation.
28 | TORCH_API std::vector<int64_t> DropDimensions( |
29 | c10::ArrayRef<int64_t> sizes, |
30 | c10::ArrayRef<int64_t> drop_dims); |
31 | |
32 | // Gets the canonical index of `dim` in the [0, rank) interval. Negative values |
33 | // of `dim` are interpreted as follows: -1 is rank-1, -2 is rank-2, etc. |
// NOTE(review): how a `dim` outside [-rank, rank) is handled is not visible
// from this declaration -- check the implementation before relying on it.
34 | TORCH_API int64_t GetCanonicalDimensionIndex(int64_t dim, int64_t rank); |
35 | |
36 | // Same as GetCanonicalDimensionIndex, for multiple dimensions given in `dimensions`. |
// All entries are canonicalized against the same `rank`.
37 | TORCH_API std::vector<int64_t> GetCanonicalDimensionIndices( |
38 | c10::ArrayRef<int64_t> dimensions, |
39 | int64_t rank); |
40 | |
41 | // Returns the canonical position for `pos` in the `dim` dimension of |
42 | // `dimensions`, handling negative values for the position. |
// NOTE(review): a negative `pos` presumably counts back from the end of
// dimension `dim`, and `dimensions` presumably holds the per-dimension sizes
// -- confirm against the implementation.
43 | TORCH_API int64_t GetCanonicalPosition( |
44 | c10::ArrayRef<int64_t> dimensions, |
45 | int64_t dim, |
46 | int64_t pos); |
47 | |
48 | // Creates the permutation that transposes `dim0` and `dim1` for an input of the given `rank`. |
// NOTE(review): presumably the identity permutation over `rank` dimensions
// with `dim0` and `dim1` swapped; whether negative dims are accepted is not
// visible from this declaration -- confirm.
49 | TORCH_API std::vector<int64_t> MakeTransposePermutation( |
50 | int64_t dim0, |
51 | int64_t dim1, |
52 | int64_t rank); |
53 | |
54 | // Calculates the promoted shape to which the input shapes should be |
55 | // broadcast for an elementwise operation. The sizes of the common dimensions |
56 | // (in the example below: 2,3,4 for shape1, and 0,1,2 for shape2) must either |
57 | // match, or one of the two must be 1. |
58 | // Example: |
59 | // shape1 = [9, 7, 6, 1, 2] |
60 | // shape2 = [6, 5, 2] |
61 | // result_shape = [9, 7, 6, 5, 2] |
62 | TORCH_API std::vector<int64_t> GetPromotedShape( |
63 | c10::ArrayRef<int64_t> shape1_dims, |
64 | c10::ArrayRef<int64_t> shape2_dims); |
65 | |
// NOTE(review): undocumented in the original. Judging by the name and
// signature, this presumably returns the result Shape of a binary operation
// on operands shaped `shape1` and `shape2`, with dimensions promoted as in
// GetPromotedShape; how the element type is chosen is not visible here --
// confirm against the implementation.
66 | TORCH_API Shape |
67 | GetPromotedBinaryOpShape(const Shape& shape1, const Shape& shape2); |
68 | |
// NOTE(review): undocumented in the original. Presumably splits `text` at
// each occurrence of `delim` and returns the pieces as owning strings;
// whether empty pieces are kept is not visible here -- confirm against the
// implementation.
69 | TORCH_API std::vector<std::string> StrSplit(c10::string_view text, char delim); |
70 | |
71 | } // namespace lazy |
72 | } // namespace torch |
73 | |