1// Copyright 2022 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <stdbool.h>
7#include <stddef.h>
8#include <string.h>
9
10#include <xnnpack/math.h>
11
// Decides whether dimension `dim` can be dropped or merged into its outer
// neighbour without changing which memory locations the strides address.
//
// Any stride array may be NULL, in which case that side imposes no
// constraint. Otherwise the dimension must be tightly packed inside the
// preceding one:
//   input_stride[dim - 1]        == input_stride[dim]        * shape[dim]
//   output_stride[perm[dim] - 1] == output_stride[perm[dim]] * shape[dim]
// A side whose index is 0 has no preceding entry and is not checked; when
// both dim and perm[dim] are 0 the answer is always true.
//
// NOTE(review): the fold loop in xnn_normalize_transpose_permutation passes
// an input-axis index (perm[output_pos]) as `dim`, while this function uses
// perm[dim] as the index into output_stride. The two agree only when
// perm[perm[p]] == p - confirm this is intended.
static bool can_dimension_be_removed(
    const size_t* input_stride,
    const size_t* output_stride,
    const size_t* shape,
    const size_t* perm,
    size_t dim)
{
  // Index 0 on both sides: there is no outer neighbour to be packed into.
  const bool leading_on_both_sides = (dim == 0) && (perm[dim] == 0);
  if (leading_on_both_sides) {
    return true;
  }

  const bool input_is_packed =
    input_stride == NULL || dim == 0 || input_stride[dim - 1] == input_stride[dim] * shape[dim];
  if (!input_is_packed) {
    return false;
  }

  if (output_stride == NULL) {
    return true;
  }
  const size_t out_index = perm[dim];
  return out_index == 0 || output_stride[out_index - 1] == output_stride[out_index] * shape[dim];
}
34
// Erases array[index] by sliding entries index + 1 .. count - 1 down by one
// slot. The final slot keeps its previous value; nothing is done when
// index + 1 >= count.
static void erase_entry(size_t* array, size_t index, size_t count)
{
  for (size_t slot = index; slot + 1 < count; ++slot) {
    array[slot] = array[slot + 1];
  }
}

// Drops the dimension found at position `dim` of perm (i.e. axis perm[dim])
// from every descriptor array, in place:
//   - shape:         entry perm[dim] is erased.
//   - input_stride:  entry perm[dim] - 1 is erased (entry 0 if perm[dim] is 0).
//   - output_stride: entry dim - 1 is erased (entry 0 if dim is 0).
//   - perm:          values greater than perm[dim] are decremented, then
//                    entry dim is erased.
// input_stride and output_stride may be NULL and are then skipped. The arrays
// are not shrunk; the caller tracks the reduced dimension count itself.
static void remove_dimension(
    size_t* shape,
    size_t* perm,
    size_t* input_stride,
    size_t* output_stride,
    size_t num_dims,
    size_t dim)
{
  const size_t removed_axis = perm[dim];

  erase_entry(shape, removed_axis, num_dims);

  if (input_stride != NULL) {
    const size_t input_slot = removed_axis > 0 ? removed_axis - 1 : 0;
    erase_entry(input_stride, input_slot, num_dims);
  }
  if (output_stride != NULL) {
    const size_t output_slot = dim > 0 ? dim - 1 : 0;
    erase_entry(output_stride, output_slot, num_dims);
  }

  // Renumber the axes that came after the removed one so perm stays dense.
  for (size_t pos = 0; pos < num_dims; ++pos) {
    if (perm[pos] > removed_axis) {
      perm[pos] -= 1;
    }
  }
  erase_entry(perm, dim, num_dims);
}
// Rewrites a transpose description (perm, shape, optional strides) into an
// equivalent one with as few dimensions as possible.
//
// Reductions applied:
//   1. A dimension of extent 1 is dropped.
//   2. A dimension that directly follows its output predecessor in the input
//      as well (perm[i] == perm[i - 1] + 1) is merged into that predecessor.
//   3. If the last dimension stays last (perm[n - 1] == n - 1), its extent is
//      folded into the element size. The dimension is then dropped when more
//      than one dimension remains and the strides allow it; otherwise its
//      extent becomes 1 and any caller-provided strides for it are multiplied
//      by the old extent.
// Reductions 1 and 2 are only performed when can_dimension_be_removed()
// accepts the dimension for the given strides.
//
// Inputs:
//   num_dims       number of entries in perm, shape and the stride arrays.
//   element_size   size of one element; the unit of the returned strides.
//   perm           perm[i] is the input dimension placed at output position i.
//   shape          input extents, indexed by input dimension.
//   input_stride   optional (may be NULL) input strides, counted in elements.
//   output_stride  optional (may be NULL) output strides, counted in elements.
//
// Outputs (each array must have room for max(num_dims, 1) entries):
//   normalized_num_dims          resulting number of dimensions (at least 1).
//   normalized_element_size_out  element_size, times the folded last extent
//                                when reduction 3 applied.
//   normalized_perm, normalized_shape
//                                the reduced permutation and shape.
//   normalized_input_stride, normalized_output_stride
//                                strides scaled by element_size. They are
//                                derived from the reduced shape when the
//                                corresponding input array was NULL.
void xnn_normalize_transpose_permutation(
    const size_t num_dims,
    const size_t element_size,
    const size_t* perm,
    const size_t* shape,
    const size_t* input_stride,
    const size_t* output_stride,
    size_t* normalized_num_dims,
    size_t* normalized_element_size_out,
    size_t* normalized_perm,
    size_t* normalized_shape,
    size_t* normalized_input_stride,
    size_t* normalized_output_stride)
{
  size_t output_dims = num_dims;
  memcpy(normalized_perm, perm, num_dims * sizeof(size_t));
  memcpy(normalized_shape, shape, num_dims * sizeof(size_t));

  // The *_ptr aliases are non-NULL only when the caller supplied strides. They
  // are what must be tested before reading a stride buffer: the output buffers
  // themselves are always valid pointers but hold no data until filled below.
  size_t* normalized_input_stride_ptr = NULL;
  size_t* normalized_output_stride_ptr = NULL;
  if (input_stride != NULL) {
    memcpy(normalized_input_stride, input_stride, num_dims * sizeof(size_t));
    normalized_input_stride_ptr = normalized_input_stride;
  }
  if (output_stride != NULL) {
    memcpy(normalized_output_stride, output_stride, num_dims * sizeof(size_t));
    normalized_output_stride_ptr = normalized_output_stride;
  }

  // Remove dimensions of size 1 and fold dimensions which are adjacent in both
  // input and output tensors.
  size_t output_pos = 0;
  while (output_pos < output_dims) {
    const size_t axis = normalized_perm[output_pos];
    const bool removable =
      can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr, normalized_shape,
                               normalized_perm, axis)
      && (normalized_shape[axis] == 1
          || (output_pos > 0 && axis == normalized_perm[output_pos - 1] + 1));
    if (!removable) {
      output_pos += 1;
      continue;
    }
    if (output_pos > 0) {
      // Merge the extent into the previous output dimension (a no-op multiply
      // when the removed extent is 1).
      normalized_shape[normalized_perm[output_pos - 1]] *= normalized_shape[axis];
    }
    remove_dimension(normalized_shape, normalized_perm, normalized_input_stride_ptr, normalized_output_stride_ptr,
                     output_dims, output_pos);
    output_dims -= 1;
    // When a dimension has been removed, new folds may be possible with the
    // previous position, so check it again.
    if (output_pos > 0) {
      output_pos -= 1;
    }
  }

  // The loop ends with output_pos == output_dims, so 0 means every dimension
  // was removed (all extents were 1, or num_dims was 0). Describe the tensor
  // as a single element.
  if (output_pos == 0) {
    *normalized_num_dims = 1;
    *normalized_element_size_out = element_size;
    normalized_perm[0] = 0;
    normalized_shape[0] = 1;
    normalized_input_stride[0] = element_size;
    normalized_output_stride[0] = element_size;
    return;
  }

  // If the last input and output dimensions are the same, treat it as one
  // large element.
  size_t normalized_element_size = element_size;
  if (normalized_perm[output_dims - 1] == output_dims - 1) {
    normalized_element_size = element_size * normalized_shape[output_dims - 1];
    if (output_dims > 1 && can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr,
                                                    normalized_shape, normalized_perm, output_dims - 1)) {
      output_dims -= 1;
    } else {
      // Keep the dimension with extent 1. Only caller-provided strides are
      // adjusted here: when none were given the buffers are still
      // uninitialized and are filled in from scratch further down.
      if (normalized_input_stride_ptr != NULL) {
        normalized_input_stride_ptr[output_dims - 1] *= normalized_shape[output_dims - 1];
      }
      if (normalized_output_stride_ptr != NULL) {
        normalized_output_stride_ptr[normalized_perm[output_dims - 1]] *= normalized_shape[output_dims - 1];
      }
      normalized_shape[output_dims - 1] = 1;
    }
  }

  // If input_stride is not provided, calculate it using normalized_shape and
  // normalized_element_size.
  if (input_stride == NULL) {
    normalized_input_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_input_stride[i - 1] = normalized_input_stride[i] * normalized_shape[i];
    }
  } else {
    // Scale input_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_input_stride[i] *= element_size;
    }
  }

  // If output_stride is not provided, calculate it using normalized_shape and
  // normalized_element_size. Output position i has extent shape[perm[i]].
  if (output_stride == NULL) {
    normalized_output_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_output_stride[i - 1] = normalized_output_stride[i] * normalized_shape[normalized_perm[i]];
    }
  } else {
    // Scale output_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_output_stride[i] *= element_size;
    }
  }

  *normalized_element_size_out = normalized_element_size;
  *normalized_num_dims = output_dims;
}
172