// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <stdbool.h>
#include <stddef.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/math.h>

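// Normalizes a slice: slice dimensions of size 1 are folded into an adjacent inner dimension, and adjacent
// dimensions that are copied in full are merged, so the slice is described with fewer dimensions where possible.
// The normalized offsets and shapes are written right-aligned into arrays of XNN_MAX_TENSOR_DIMS entries (unused
// leading entries get offset 0 and size 1), and num_normalized_dims receives the number of remaining dimensions.
// For example, slicing a {3, 4, 5} input with offsets {1, 0, 0} and sizes {2, 4, 5} normalizes to a single
// dimension with offset 20, input size 60 and output size 40.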
void xnn_normalize_slice(
    const size_t num_dims,
    const size_t offsets[XNN_MIN_ELEMENTS(1)],
    const size_t sizes[XNN_MIN_ELEMENTS(1)],
    const size_t input_shape[XNN_MIN_ELEMENTS(1)],
    size_t normalized_offsets[XNN_MIN_ELEMENTS(XNN_MAX_TENSOR_DIMS)],
    size_t normalized_input_shape[XNN_MIN_ELEMENTS(XNN_MAX_TENSOR_DIMS)],
    size_t normalized_output_shape[XNN_MIN_ELEMENTS(XNN_MAX_TENSOR_DIMS)],
    size_t* num_normalized_dims)
{
  *num_normalized_dims = num_dims;
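  // Start from the identity: offset 0, input and output size 1 in every dimension.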
  for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
    normalized_offsets[i] = 0;
    normalized_input_shape[i] = 1;
    normalized_output_shape[i] = 1;
  }

  // The first normalization pass removes all slice dimensions of size 1 by merging them into an adjacent inner dimension.
  size_t num_size_one = 0;
  for (size_t i = 0; i < num_dims; i++) {
    const size_t offset = offsets[num_dims - 1 - i];
    const size_t size = sizes[num_dims - 1 - i];
    const size_t input_dim = input_shape[num_dims - 1 - i];

    // If the innermost dimension is size 1, we can't merge it anywhere, so skip it.
    if (size == 1 && i != 0) {
      normalized_offsets[XNN_MAX_TENSOR_DIMS - 1 - i + 1 + num_size_one] +=
          offset * normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1 - i + 1 + num_size_one];
      normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1 - i + 1 + num_size_one] *= input_dim;
      normalized_output_shape[XNN_MAX_TENSOR_DIMS - 1 - i + 1 + num_size_one] *= size;
      num_size_one++;
    } else {
      normalized_offsets[XNN_MAX_TENSOR_DIMS - 1 - i + num_size_one] = offset;
      normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1 - i + num_size_one] = input_dim;
      normalized_output_shape[XNN_MAX_TENSOR_DIMS - 1 - i + num_size_one] = size;
    }
  }

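  // The second normalization pass merges adjacent dimensions that are copied in full (offset 0 and
  // size equal to the input dimension) into the next outer dimension.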
  size_t new_num_dims = num_dims - num_size_one;
  size_t output_dims = new_num_dims;
  bool merge_previous_dim = false;
  size_t num_sliced_dims = 0;
  for (size_t i = 0; i < new_num_dims; i++) {
    const size_t offset = normalized_offsets[XNN_MAX_TENSOR_DIMS - 1 - i];
    const size_t size = normalized_output_shape[XNN_MAX_TENSOR_DIMS - 1 - i];
    const size_t input_dim = normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1 - i];

    const bool merge_current_dim = (offset == 0 && size == input_dim);
    if (merge_previous_dim) {
      normalized_offsets[XNN_MAX_TENSOR_DIMS - 1 - num_sliced_dims] =
          offset * normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1 - num_sliced_dims];
      normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1 - num_sliced_dims] *= input_dim;
      normalized_output_shape[XNN_MAX_TENSOR_DIMS - 1 - num_sliced_dims] *= size;
      output_dims -= 1;
      if (!merge_current_dim) {
        num_sliced_dims += 1;
      }
    } else {
      normalized_offsets[XNN_MAX_TENSOR_DIMS - 1 - num_sliced_dims] = offset;
      normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1 - num_sliced_dims] = input_dim;
      normalized_output_shape[XNN_MAX_TENSOR_DIMS - 1 - num_sliced_dims] = size;
      if (!merge_current_dim) {
        // If merge_current_dim, we can merge the current dimension with the next (outer) dim, so don't advance num_sliced_dims.
        num_sliced_dims += 1;
      }
    }
    merge_previous_dim = merge_current_dim;
  }

  // new_num_dims <= num_dims because size-1 dimensions were merged away, so stale values may be left
  // at the front of the normalized arrays; reset them to their default values.
  for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS - output_dims; i++) {
    normalized_offsets[i] = 0;
    normalized_input_shape[i] = 1;
    normalized_output_shape[i] = 1;
  }

  *num_normalized_dims = output_dims;
}

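// A dimension can only be folded away when it is contiguous with its outer neighbor, i.e. the outer
// stride equals the inner stride multiplied by the inner dimension's extent; a NULL stride array
// means the corresponding tensor is densely packed, so that check is skipped.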
// Returns true if the input and output strides are NULL, or if the expected input/output strides
// match the actual input/output strides.
static bool can_dimension_be_removed(
    const size_t* input_stride,
    const size_t* output_stride,
    const size_t* shape,
    const size_t* perm,
    size_t dim) {
  if (dim == 0 && perm[dim] == 0) {
    return true;
  }
  if (input_stride != NULL && dim > 0) {
    if (input_stride[dim - 1] != input_stride[dim] * shape[dim]) {
      return false;
    }
  }
  if (output_stride != NULL && perm[dim] > 0) {
    if (output_stride[perm[dim] - 1] != output_stride[perm[dim]] * shape[dim]) {
      return false;
    }
  }
  return true;
}

// Remove dimension perm[dim] from shape, perm, and the input & output strides, renumbering the
// remaining entries of perm.
static void remove_dimension(
    size_t* shape,
    size_t* perm,
    size_t* input_stride,
    size_t* output_stride,
    size_t num_dims,
    size_t dim)
{
  for (size_t j = perm[dim]; j + 1 < num_dims; ++j) {
    shape[j] = shape[j + 1];
  }
  if (input_stride != NULL) {
    for (size_t j = max(1, perm[dim]) - 1; j + 1 < num_dims; ++j) {
      input_stride[j] = input_stride[j + 1];
    }
  }
  if (output_stride != NULL) {
    for (size_t j = max(1, dim) - 1; j + 1 < num_dims; ++j) {
      output_stride[j] = output_stride[j + 1];
    }
  }
  for (size_t j = 0; j < num_dims; ++j) {
    if (perm[j] > perm[dim]) {
      perm[j] -= 1;
    }
  }
  for (size_t j = dim; j + 1 < num_dims; ++j) {
    perm[j] = perm[j + 1];
  }
}

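// Normalizes a transpose: removes dimensions of size 1, folds dimensions that stay adjacent in both
// the input and the output tensor, and, when the innermost dimension is not permuted, absorbs it into
// the element size. The normalized strides are returned in the same units as element_size (strides
// passed in are assumed to be in elements and are scaled by element_size). For example, transposing a
// {2, 3, 4} tensor with perm {2, 0, 1} normalizes to a plain 2-D transpose of a 6x4 matrix.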
void xnn_normalize_transpose_permutation(
    const size_t num_dims,
    const size_t element_size,
    const size_t* perm,
    const size_t* shape,
    const size_t* input_stride,
    const size_t* output_stride,
    size_t* normalized_num_dims,
    size_t* normalized_element_size_out,
    size_t* normalized_perm,
    size_t* normalized_shape,
    size_t* normalized_input_stride,
    size_t* normalized_output_stride)
{
  size_t output_dims = num_dims;
  memcpy(normalized_perm, perm, num_dims * sizeof(size_t));
  memcpy(normalized_shape, shape, num_dims * sizeof(size_t));
  size_t* normalized_input_stride_ptr = NULL;
  size_t* normalized_output_stride_ptr = NULL;
  if (input_stride != NULL) {
    memcpy(normalized_input_stride, input_stride, num_dims * sizeof(size_t));
    normalized_input_stride_ptr = normalized_input_stride;
  }
  if (output_stride != NULL) {
    memcpy(normalized_output_stride, output_stride, num_dims * sizeof(size_t));
    normalized_output_stride_ptr = normalized_output_stride;
  }

  size_t output_pos = 0;
  // Remove dimensions of size 1 and fold dimensions which are adjacent in both input and output tensors.
  for (; output_pos < output_dims;) {
    if (can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr, normalized_shape,
                                 normalized_perm, normalized_perm[output_pos])
        && ((normalized_shape[normalized_perm[output_pos]] == 1)
            || (output_pos > 0 && normalized_perm[output_pos] == normalized_perm[output_pos - 1] + 1))) {
      if (output_pos > 0) {
        normalized_shape[normalized_perm[output_pos - 1]] *= normalized_shape[normalized_perm[output_pos]];
      }
      remove_dimension(normalized_shape, normalized_perm, normalized_input_stride_ptr, normalized_output_stride_ptr,
                       output_dims, output_pos);
      output_dims -= 1;
      // When a dimension has been removed, new folds may be possible, so check
      // it again.
      if (output_pos > 0) {
        output_pos -= 1;
      }
    } else {
      output_pos += 1;
    }
  }
  // All dimensions are size 1.
  if (output_pos == 0) {
    *normalized_num_dims = 1;
    *normalized_element_size_out = element_size;
    normalized_perm[0] = 0;
    normalized_shape[0] = 1;
    normalized_input_stride[0] = element_size;
    normalized_output_stride[0] = element_size;
    return;
  }

  // If the last input and output dimensions are the same, treat them as one large
  // element.
  size_t normalized_element_size = element_size;
  if (normalized_perm[output_dims - 1] == output_dims - 1) {
    normalized_element_size = element_size * normalized_shape[output_dims - 1];
    if (output_dims > 1 && can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr,
                                                    normalized_shape, normalized_perm, output_dims - 1)) {
      output_dims -= 1;
    } else {
      if (normalized_input_stride != NULL) {
        normalized_input_stride[output_dims - 1] *= normalized_shape[output_dims - 1];
      }
      if (normalized_output_stride != NULL) {
        normalized_output_stride[normalized_perm[output_dims - 1]] *= normalized_shape[output_dims - 1];
      }
      normalized_shape[output_dims - 1] = 1;
    }
  }
  // If input_stride is not provided, calculate it using normalized_shape and normalized_element_size.
  if (input_stride == NULL) {
    normalized_input_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_input_stride[i - 1] = normalized_input_stride[i] * normalized_shape[i];
    }
  } else {
    // Scale input_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_input_stride[i] *= element_size;
    }
  }
  // If output_stride is not provided, calculate it using normalized_shape and normalized_element_size.
  if (output_stride == NULL) {
    normalized_output_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_output_stride[i - 1] = normalized_output_stride[i] * normalized_shape[normalized_perm[i]];
    }
  } else {
    // Scale output_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_output_stride[i] *= element_size;
    }
  }
  *normalized_element_size_out = normalized_element_size;
  *normalized_num_dims = output_dims;
}