#include <torch/csrc/lazy/core/helpers.h>

#include <c10/util/Half.h>
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/tensor_util.h>

#include <limits>

namespace torch {
namespace lazy {

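// Removes from `sizes` the entries at the indices listed in `drop_dims`.
// `drop_dims` is expected to be sorted in ascending order and to contain
// canonical (non-negative, in-range) dimension indices; the final check
// fires if any drop index was left unmatched. For example,
// DropDimensions({2, 3, 4, 5}, {1, 3}) yields {2, 4}.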
std::vector<int64_t> DropDimensions(
    c10::ArrayRef<int64_t> sizes,
    c10::ArrayRef<int64_t> drop_dims) {
  std::vector<int64_t> new_dims;
  size_t drop_index = 0;
  for (const auto i : c10::irange(sizes.size())) {
    if (drop_index < drop_dims.size() &&
        static_cast<int64_t>(i) == drop_dims[drop_index]) {
      ++drop_index;
    } else {
      new_dims.push_back(sizes[i]);
    }
  }
  // Every entry in drop_dims must have been matched against a dimension.
  TORCH_CHECK(drop_index == drop_dims.size());
  return new_dims;
}

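// Maps a possibly negative dimension index into the canonical range
// [0, rank), following Python-style negative indexing: with rank 4,
// dim -1 maps to 3 and dim 2 stays 2. Indices outside [-rank, rank - 1]
// fail the range check.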
int64_t GetCanonicalDimensionIndex(int64_t dim, int64_t rank) {
  int64_t min_shape_dim = -rank;
  int64_t max_shape_dim = rank - 1;
  TORCH_CHECK(
      min_shape_dim <= dim && dim <= max_shape_dim,
      "Value out of range (expected to be in range of [",
      min_shape_dim,
      ", ",
      max_shape_dim,
      "], but got ",
      dim,
      ")");
  int64_t dim_index = dim < 0 ? rank + dim : dim;
  TORCH_CHECK(dim_index >= 0);
  TORCH_CHECK(dim_index < rank);
  return dim_index;
}

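// Canonicalizes every index in `dimensions` against the same rank, e.g.
// {-1, 0} with rank 3 becomes {2, 0}.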
std::vector<int64_t> GetCanonicalDimensionIndices(
    c10::ArrayRef<int64_t> dimensions,
    int64_t rank) {
  std::vector<int64_t> canonical_dim_indices;
  for (int64_t dim : dimensions) {
    canonical_dim_indices.push_back(GetCanonicalDimensionIndex(dim, rank));
  }
  return canonical_dim_indices;
}

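// Canonicalizes a position `pos` within dimension `dim` of a shape described
// by `dimensions`: negative positions wrap around the dimension's size, and
// non-negative ones are clamped to it. With dimensions {2, 3, 4} and dim 1,
// pos -1 maps to 2 while pos 10 is clamped to 3.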
int64_t GetCanonicalPosition(
    c10::ArrayRef<int64_t> dimensions,
    int64_t dim,
    int64_t pos) {
  dim = GetCanonicalDimensionIndex(dim, dimensions.size());
  if (pos < 0) {
    pos = GetCanonicalDimensionIndex(pos, dimensions[dim]);
  } else {
    pos = std::min<int64_t>(pos, dimensions[dim]);
  }
  return pos;
}

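// Builds the permutation that swaps dim0 and dim1 (either may be negative)
// and leaves all other dimensions in place; for rank 4, dims 0 and -2 give
// {2, 1, 0, 3}.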
std::vector<int64_t> MakeTransposePermutation(
    int64_t dim0,
    int64_t dim1,
    int64_t rank) {
  int64_t canonical_dim0 = GetCanonicalDimensionIndex(dim0, rank);
  int64_t canonical_dim1 = GetCanonicalDimensionIndex(dim1, rank);
  auto permute_dims = Iota<int64_t>(rank);
  std::swap(permute_dims[canonical_dim0], permute_dims[canonical_dim1]);
  return permute_dims;
}

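// Computes the NumPy-style broadcast of two shapes: the shapes are
// right-aligned, missing leading dimensions are taken from the longer shape,
// and aligned dimensions must be equal or have one side be 1. For example,
// {9, 7, 6, 5, 2} and {6, 1, 2} broadcast to {9, 7, 6, 5, 2}.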
std::vector<int64_t> GetPromotedShape(
    c10::ArrayRef<int64_t> shape1_dims,
    c10::ArrayRef<int64_t> shape2_dims) {
  std::vector<int64_t> dimensions;
  // If the rank of one shape is bigger than the other's, fill the leading
  // dimensions with those of the bigger shape.
  // Example:
  //   shape1 = [9, 7, 6, 5, 2]
  //   shape2 =       [6, 1, 2]
  // Insert [9, 7] into the dimensions vector.
  if (shape1_dims.size() > shape2_dims.size()) {
    dimensions.insert(
        dimensions.end(),
        shape1_dims.begin(),
        shape1_dims.begin() + (shape1_dims.size() - shape2_dims.size()));
  } else if (shape2_dims.size() > shape1_dims.size()) {
    dimensions.insert(
        dimensions.end(),
        shape2_dims.begin(),
        shape2_dims.begin() + (shape2_dims.size() - shape1_dims.size()));
  }
  // The common (right-aligned) dimensions must match, or one of them be 1.
  size_t min_size = std::min(shape1_dims.size(), shape2_dims.size());
  for (const auto i : c10::irange(min_size)) {
    int64_t dim1 = shape1_dims[shape1_dims.size() - min_size + i];
    int64_t dim2 = shape2_dims[shape2_dims.size() - min_size + i];
    TORCH_CHECK(
        dim1 == dim2 || dim1 == 1 || dim2 == 1,
        "(",
        c10::Join(", ", shape1_dims),
        ") and (",
        c10::Join(", ", shape2_dims),
        ")");
    if (dim1 == 0 || dim2 == 0) {
      // Broadcasting against a zero-sized dimension yields a zero-sized
      // dimension.
      dimensions.push_back(0);
    } else {
      dimensions.push_back(std::max<int64_t>(dim1, dim2));
    }
  }
  return dimensions;
}

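// Combines dtype promotion (promoteTypes) with shape broadcasting
// (GetPromotedShape) to produce the result shape of an elementwise binary op.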
Shape GetPromotedBinaryOpShape(const Shape& shape1, const Shape& shape2) {
  return Shape(
      promoteTypes(shape1.scalar_type(), shape2.scalar_type()),
      GetPromotedShape(shape1.sizes(), shape2.sizes()));
}

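// Splits `text` on `delim`, skipping empty tokens so that consecutive,
// leading, or trailing delimiters are collapsed; e.g. StrSplit("a::b:c", ':')
// yields {"a", "b", "c"}.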
std::vector<std::string> StrSplit(c10::string_view text, char delim) {
  size_t start = 0;
  size_t end = 0;

  std::vector<std::string> tokens;
  while ((start = text.find_first_not_of(delim, end)) !=
         c10::string_view::npos) {
    end = text.find(delim, start);
    auto token = text.substr(start, end - start);
    tokens.emplace_back(token.begin(), token.end());
  }
  return tokens;
}

} // namespace lazy
} // namespace torch