1 | #include <torch/csrc/lazy/core/helpers.h> |
2 | |
3 | #include <c10/util/Half.h> |
4 | #include <c10/util/irange.h> |
5 | #include <torch/csrc/lazy/core/tensor_util.h> |
6 | |
7 | #include <limits> |
8 | |
9 | namespace torch { |
10 | namespace lazy { |
11 | |
12 | std::vector<int64_t> DropDimensions( |
13 | c10::ArrayRef<int64_t> sizes, |
14 | c10::ArrayRef<int64_t> drop_dims) { |
15 | std::vector<int64_t> new_dims; |
16 | size_t drop_index = 0; |
17 | for (const auto i : c10::irange(sizes.size())) { |
18 | if (drop_index < drop_dims.size() && i == drop_dims[drop_index]) { |
19 | ++drop_index; |
20 | } else { |
21 | new_dims.push_back(sizes[i]); |
22 | } |
23 | } |
24 | TORCH_CHECK(drop_index == drop_dims.size()); |
25 | return new_dims; |
26 | } |
27 | |
28 | int64_t GetCanonicalDimensionIndex(int64_t dim, int64_t rank) { |
29 | int64_t min_shape_dim = -rank; |
30 | int64_t max_shape_dim = rank - 1; |
31 | TORCH_CHECK( |
32 | min_shape_dim <= dim && dim <= max_shape_dim, |
33 | "Value out of range (expected to be in range of [" , |
34 | min_shape_dim, |
35 | ", " , |
36 | max_shape_dim, |
37 | "], but got " , |
38 | dim, |
39 | ")" ); |
40 | int64_t dim_index = dim < 0 ? rank + dim : dim; |
41 | TORCH_CHECK(dim_index >= 0); |
42 | TORCH_CHECK(dim_index < rank); |
43 | return dim_index; |
44 | } |
45 | |
46 | std::vector<int64_t> GetCanonicalDimensionIndices( |
47 | c10::ArrayRef<int64_t> dimensions, |
48 | int64_t rank) { |
49 | std::vector<int64_t> canonical_dim_indices; |
50 | for (int64_t dim : dimensions) { |
51 | canonical_dim_indices.push_back(GetCanonicalDimensionIndex(dim, rank)); |
52 | } |
53 | return canonical_dim_indices; |
54 | } |
55 | |
56 | int64_t GetCanonicalPosition( |
57 | c10::ArrayRef<int64_t> dimensions, |
58 | int64_t dim, |
59 | int64_t pos) { |
60 | dim = GetCanonicalDimensionIndex(dim, dimensions.size()); |
61 | if (pos < 0) { |
62 | pos = GetCanonicalDimensionIndex(pos, dimensions[dim]); |
63 | } else { |
64 | pos = std::min<int64_t>(pos, dimensions[dim]); |
65 | } |
66 | return pos; |
67 | } |
68 | |
69 | std::vector<int64_t> MakeTransposePermutation( |
70 | int64_t dim0, |
71 | int64_t dim1, |
72 | int64_t rank) { |
73 | int64_t canonical_dim0 = GetCanonicalDimensionIndex(dim0, rank); |
74 | int64_t canonical_dim1 = GetCanonicalDimensionIndex(dim1, rank); |
75 | auto permute_dims = Iota<int64_t>(rank); |
76 | std::swap(permute_dims[canonical_dim0], permute_dims[canonical_dim1]); |
77 | return permute_dims; |
78 | } |
79 | |
80 | std::vector<int64_t> GetPromotedShape( |
81 | c10::ArrayRef<int64_t> shape1_dims, |
82 | c10::ArrayRef<int64_t> shape2_dims) { |
83 | std::vector<int64_t> dimensions; |
84 | // If the rank of a shape is bigger than then other, fill up the first |
85 | // dimensions with the ones of the bigger. |
86 | // Example: |
87 | // shape1 = [9, 7, 6, 5, 2] |
88 | // shape2 = [6, 1, 2] |
89 | // Insert [9, 7] into the dimensions vector. |
90 | if (shape1_dims.size() > shape2_dims.size()) { |
91 | dimensions.insert( |
92 | dimensions.end(), |
93 | shape1_dims.begin(), |
94 | shape1_dims.begin() + (shape1_dims.size() - shape2_dims.size())); |
95 | } else if (shape2_dims.size() > shape1_dims.size()) { |
96 | dimensions.insert( |
97 | dimensions.end(), |
98 | shape2_dims.begin(), |
99 | shape2_dims.begin() + (shape2_dims.size() - shape1_dims.size())); |
100 | } |
101 | // For the common dimensions, they must match, or one of them be 1. |
102 | size_t min_size = std::min(shape1_dims.size(), shape2_dims.size()); |
103 | for (const auto i : c10::irange(min_size)) { |
104 | int64_t dim1 = shape1_dims[shape1_dims.size() - min_size + i]; |
105 | int64_t dim2 = shape2_dims[shape2_dims.size() - min_size + i]; |
106 | TORCH_CHECK( |
107 | dim1 == dim2 || dim1 == 1 || dim2 == 1, |
108 | "(" , |
109 | c10::Join(", " , shape1_dims), |
110 | ") and (" , |
111 | c10::Join(", " , shape1_dims), |
112 | ")" ); |
113 | if (dim1 == 0 || dim2 == 0) { |
114 | dimensions.push_back(0); |
115 | } else { |
116 | dimensions.push_back(std::max<int64_t>(dim1, dim2)); |
117 | } |
118 | } |
119 | return dimensions; |
120 | } |
121 | |
122 | Shape GetPromotedBinaryOpShape(const Shape& shape1, const Shape& shape2) { |
123 | return Shape( |
124 | promoteTypes(shape1.scalar_type(), shape2.scalar_type()), |
125 | GetPromotedShape(shape1.sizes(), shape2.sizes())); |
126 | } |
127 | |
128 | std::vector<std::string> StrSplit(c10::string_view text, char delim) { |
129 | size_t start = 0; |
130 | size_t end = 0; |
131 | |
132 | std::vector<std::string> tokens; |
133 | while ((start = text.find_first_not_of(delim, end)) != std::string::npos) { |
134 | end = text.find(delim, start); |
135 | auto token = text.substr(start, end - start); |
136 | tokens.emplace_back(token.begin(), token.end()); |
137 | } |
138 | return tokens; |
139 | } |
140 | |
141 | } // namespace lazy |
142 | } // namespace torch |
143 | |