1 | // Copyright 2022 Google LLC |
2 | // |
3 | // This source code is licensed under the BSD-style license found in the |
4 | // LICENSE file in the root directory of this source tree. |
5 | |
6 | #include <stdbool.h> |
7 | #include <stddef.h> |
8 | #include <string.h> |
9 | |
10 | #include <xnnpack/math.h> |
11 | |
// Checks whether dimension `dim` is laid out contiguously with the dimension
// in front of it, so that dropping it would not contradict the given strides.
//
// A NULL stride array is never consulted. Otherwise:
//   - input_stride[dim - 1] must equal input_stride[dim] * shape[dim]
//     (skipped when dim == 0);
//   - output_stride[perm[dim] - 1] must equal output_stride[perm[dim]] * shape[dim]
//     (skipped when perm[dim] == 0).
// When both dim and perm[dim] are 0 the answer is true without looking at any
// stride.
//
// NOTE(review): output_stride is indexed through perm[dim] while shape is
// indexed by dim; confirm against the callers that this mapping is intended
// for every permutation.
static bool can_dimension_be_removed(
    const size_t* input_stride,
    const size_t* output_stride,
    const size_t* shape,
    const size_t* perm,
    size_t dim)
{
  // Leading dimension on both sides: there is no outer stride to compare with.
  if (dim == 0 && perm[0] == 0) {
    return true;
  }

  // Input side: the outer stride must be exactly extent * inner stride.
  if (input_stride != NULL && dim != 0) {
    const size_t expected_outer = input_stride[dim] * shape[dim];
    if (input_stride[dim - 1] != expected_outer) {
      return false;
    }
  }

  // Output side: same test, applied at index perm[dim].
  if (output_stride == NULL) {
    return true;
  }
  const size_t out_dim = perm[dim];
  if (out_dim == 0) {
    return true;
  }
  return output_stride[out_dim - 1] == output_stride[out_dim] * shape[dim];
}
34 | |
// Deletes the element at `index` from the first `count` entries of `values` by
// moving every later entry one slot towards the front. The final slot keeps
// its previous value; nothing happens when index + 1 >= count.
static void drop_entry(size_t* values, size_t index, size_t count)
{
  for (size_t i = index + 1; i < count; ++i) {
    values[i - 1] = values[i];
  }
}

// Removes the dimension found at position `dim` of `perm` from all of the
// per-dimension arrays, each of which currently holds `num_dims` entries:
//   - shape loses entry perm[dim];
//   - input_stride (if non-NULL) loses the entry just before perm[dim], or
//     entry 0 when perm[dim] is 0;
//   - output_stride (if non-NULL) loses the entry just before dim, or entry 0
//     when dim is 0;
//   - perm values greater than perm[dim] are decremented, then perm loses
//     entry dim.
// The arrays are not shrunk: the caller is expected to treat them as holding
// num_dims - 1 entries afterwards.
static void remove_dimension(
    size_t* shape,
    size_t* perm,
    size_t* input_stride,
    size_t* output_stride,
    size_t num_dims,
    size_t dim)
{
  // Read once up front; the value stored at perm[dim] is not changed by the
  // renumbering loop below, only by the final compaction.
  const size_t removed = perm[dim];

  drop_entry(shape, removed, num_dims);
  if (input_stride != NULL) {
    drop_entry(input_stride, max(1, removed) - 1, num_dims);
  }
  if (output_stride != NULL) {
    drop_entry(output_stride, max(1, dim) - 1, num_dims);
  }

  // Renumber the dimensions that came after the removed one.
  for (size_t i = 0; i < num_dims; ++i) {
    if (perm[i] > removed) {
      --perm[i];
    }
  }
  drop_entry(perm, dim, num_dims);
}
// Reduces a transpose description to an equivalent one with fewer dimensions.
//
// Inputs:
//   num_dims       number of entries in perm, shape and (if given) the strides.
//   element_size   size of one element; every returned stride is scaled by it.
//   perm, shape    permutation and shape to normalize (not modified).
//   input_stride,
//   output_stride  optional strides expressed in elements. Either may be NULL,
//                  in which case dense strides are derived from the normalized
//                  shape.
//
// Outputs (all must be non-NULL; the array outputs receive num_dims entries and
// always have index 0 written, so they need room for max(num_dims, 1) entries):
//   normalized_num_dims          number of dimensions left.
//   normalized_element_size_out  element_size, multiplied by the extent of the
//                                last dimension when that dimension is not
//                                permuted.
//   normalized_perm, normalized_shape,
//   normalized_input_stride, normalized_output_stride
//                                the normalized description; strides are
//                                already multiplied by element_size.
//
// Steps: (1) drop size-1 dimensions and fold dimensions that are adjacent in
// both input and output, (2) absorb a non-permuted last dimension into the
// element size, (3) derive or scale the strides.
void xnn_normalize_transpose_permutation(
    const size_t num_dims,
    const size_t element_size,
    const size_t* perm,
    const size_t* shape,
    const size_t* input_stride,
    const size_t* output_stride,
    size_t* normalized_num_dims,
    size_t* normalized_element_size_out,
    size_t* normalized_perm,
    size_t* normalized_shape,
    size_t* normalized_input_stride,
    size_t* normalized_output_stride)
{
  size_t output_dims = num_dims;
  memcpy(normalized_perm, perm, num_dims * sizeof(size_t));
  memcpy(normalized_shape, shape, num_dims * sizeof(size_t));
  // The *_ptr aliases are non-NULL only when the caller supplied that stride
  // array, i.e. only when the corresponding output buffer already holds
  // meaningful values that must be kept consistent while dimensions change.
  size_t* normalized_input_stride_ptr = NULL;
  size_t* normalized_output_stride_ptr = NULL;
  if (input_stride != NULL) {
    memcpy(normalized_input_stride, input_stride, num_dims * sizeof(size_t));
    normalized_input_stride_ptr = normalized_input_stride;
  }
  if (output_stride != NULL) {
    memcpy(normalized_output_stride, output_stride, num_dims * sizeof(size_t));
    normalized_output_stride_ptr = normalized_output_stride;
  }

  size_t output_pos = 0;
  // Remove dimensions of size 1 and fold dimensions which are adjacent in both input and output tensors.
  while (output_pos < output_dims) {
    if (can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr, normalized_shape,
                                 normalized_perm, normalized_perm[output_pos])
        && ((normalized_shape[normalized_perm[output_pos]] == 1)
            || (output_pos > 0 && normalized_perm[output_pos] == normalized_perm[output_pos - 1] + 1))) {
      if (output_pos > 0) {
        // Merge the extent into the preceding output dimension (a no-op for
        // size-1 dimensions).
        normalized_shape[normalized_perm[output_pos - 1]] *= normalized_shape[normalized_perm[output_pos]];
      }
      remove_dimension(normalized_shape, normalized_perm, normalized_input_stride_ptr, normalized_output_stride_ptr,
                       output_dims, output_pos);
      output_dims -= 1;
      // When a dimension has been removed, new folds may be possible so check
      // it again.
      if (output_pos > 0) {
        output_pos -= 1;
      }
    } else {
      output_pos += 1;
    }
  }
  // Every dimension was removed (all had size 1): describe a single element.
  if (output_pos == 0) {
    *normalized_num_dims = 1;
    *normalized_element_size_out = element_size;
    normalized_perm[0] = 0;
    normalized_shape[0] = 1;
    normalized_input_stride[0] = element_size;
    normalized_output_stride[0] = element_size;
    return;
  }

  // If the last input and output dimensions are the same, treat it as one large
  // element.
  size_t normalized_element_size = element_size;
  if (normalized_perm[output_dims - 1] == output_dims - 1) {
    normalized_element_size = element_size * normalized_shape[output_dims - 1];
    if (output_dims > 1 && can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr, normalized_shape,
                                                    normalized_perm, output_dims - 1)) {
      output_dims -= 1;
    } else {
      // The dimension has to stay (it is the only one, or the strides are not
      // contiguous): keep it with extent 1 and let its stride cover the whole
      // enlarged element. Only caller-supplied strides are adjusted here;
      // derived strides are computed from scratch below, so their buffers
      // still hold no meaningful values at this point.
      if (normalized_input_stride_ptr != NULL) {
        normalized_input_stride_ptr[output_dims - 1] *= normalized_shape[output_dims - 1];
      }
      if (normalized_output_stride_ptr != NULL) {
        normalized_output_stride_ptr[normalized_perm[output_dims - 1]] *= normalized_shape[output_dims - 1];
      }
      normalized_shape[output_dims - 1] = 1;
    }
  }
  // If input_stride is not provided, calculate it using normalized_shape and normalized_element_size.
  if (input_stride == NULL) {
    normalized_input_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_input_stride[i - 1] = normalized_input_stride[i] * normalized_shape[i];
    }
  } else {
    // Scale input_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_input_stride[i] *= element_size;
    }
  }
  // If output_stride is not provided, calculate it using normalized_shape and normalized_element_size.
  if (output_stride == NULL) {
    normalized_output_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      // Output dimension i has the extent of input dimension normalized_perm[i].
      normalized_output_stride[i - 1] = normalized_output_stride[i] * normalized_shape[normalized_perm[i]];
    }
  } else {
    // Scale output_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_output_stride[i] *= element_size;
    }
  }
  *normalized_element_size_out = normalized_element_size;
  *normalized_num_dims = output_dims;
}
172 | |