1#include <string.h>
2#include <stdlib.h>
3#include <stdbool.h>
4#include <stdint.h>
5
6#include <nnpack.h>
7#include <nnpack/macros.h>
8#include <nnpack/utils.h>
9#include <nnpack/pooling.h>
10
11#include <nnpack/validation.h>
12
13
14struct NNP_CACHE_ALIGN pooling_context {
15 nnp_pooling_function pooling_function;
16 const float* input_pointer;
17 float* output_pointer;
18
19 size_t channels;
20 struct nnp_size input_size;
21 struct nnp_padding input_padding;
22 struct nnp_size output_size;
23 struct nnp_size pooling_size;
24 struct nnp_size pooling_stride;
25};
26
/*
 * Reference max pooling for a single row-major input plane.
 *
 * Output element (oy, ox) is the maximum over the pooling window whose
 * top-left corner sits at padded coordinate (oy * stride_height,
 * ox * stride_width). Cells that fall into the implicit padding are skipped.
 * A window that lies entirely in padding therefore yields -infinity.
 *
 * input_pointer:  input_height x input_width floats, row-major.
 * output_pointer: output_height x output_width floats, row-major; every
 *                 element is written.
 * padding_top/padding_left: amount of implicit padding before row/column 0.
 *                 Bottom/right padding is implied by output_height/output_width.
 */
static void compute_max_pooling_forward__generic(
	const float *restrict input_pointer,
	float *restrict output_pointer,
	size_t input_height,
	size_t input_width,
	size_t padding_top,
	size_t padding_left,
	size_t output_height,
	size_t output_width,
	uint32_t stride_height,
	uint32_t stride_width,
	uint32_t pooling_height,
	uint32_t pooling_width)
{
	/* Outputs are produced in row-major order, so one running cursor suffices. */
	float *cursor = output_pointer;

	for (size_t oy = 0; oy < output_height; oy++) {
		const size_t window_top = oy * stride_height;

		for (size_t ox = 0; ox < output_width; ox++) {
			const size_t window_left = ox * stride_width;
			float best = -__builtin_inff();

			for (size_t dy = 0; dy < pooling_height; dy++) {
				/*
				 * size_t arithmetic: a row inside the top padding wraps around to a
				 * huge value. One comparison therefore rejects both top and bottom
				 * padding.
				 */
				const size_t row = window_top + dy - padding_top;
				if (row >= input_height) {
					continue;
				}

				const float *row_data = input_pointer + row * input_width;
				for (size_t dx = 0; dx < pooling_width; dx++) {
					/* Same wrap-around trick for left/right padding. */
					const size_t column = window_left + dx - padding_left;
					if (column < input_width) {
						best = maxf(row_data[column], best);
					}
				}
			}

			*cursor++ = best;
		}
	}
}
62
63#if NNP_BACKEND_X86_64
64static void compute_max_pooling_forward_2x2_2x2__avx2(
65 const float *restrict input_pointer,
66 float *restrict output_pointer,
67 size_t input_height,
68 size_t input_width,
69 size_t padding_top,
70 size_t padding_left,
71 size_t output_height,
72 size_t output_width,
73 uint32_t stride_height,
74 uint32_t stride_width,
75 uint32_t pooling_height,
76 uint32_t pooling_width)
77{
78 const struct nnp_size output_tile = {
79 .height = 1,
80 .width = 8
81 };
82 const struct nnp_size input_tile = {
83 .height = 2,
84 .width = 16,
85 };
86
87 const float (*input)[input_width] = (const float(*)[input_width]) input_pointer;
88 float (*output)[output_width] = (float(*)[output_width]) output_pointer;
89
90 for (size_t y = 0; y < output_height; y += output_tile.height) {
91 const size_t input_y = min(doz(y * stride_height, padding_top), input_height);
92 const size_t input_row_offset = doz(padding_top, y);
93 const size_t input_row_count = min(input_tile.height, doz(input_height, input_y));
94 const size_t output_row_count = min(output_tile.height, output_height - y);
95 for (size_t x = 0; x < output_width; x += output_tile.width) {
96 const size_t input_x = min(doz(x * stride_width, padding_left), input_width);
97 const size_t input_column_offset = doz(padding_left, x);
98 const size_t input_column_count = min(input_tile.width, doz(input_width, input_x));
99 const size_t output_column_count = min(output_tile.width, output_width - x);
100 nnp_maxpool_2x2_2x2__avx2(
101 &input[input_y][input_x],
102 &output[y][x],
103 input_width,
104 input_row_offset,
105 input_row_count,
106 input_column_offset,
107 input_column_count,
108 output_column_count);
109 }
110 }
111}
112#endif
113
114static void compute_pooling_output(
115 const struct pooling_context context[restrict static 1],
116 size_t sample, size_t channel)
117{
118 const size_t channels = context->channels;
119 const struct nnp_size input_size = context->input_size;
120 const struct nnp_padding input_padding = context->input_padding;
121 const struct nnp_size output_size = context->output_size;
122 const struct nnp_size pooling_stride = context->pooling_stride;
123 const struct nnp_size pooling_size = context->pooling_size;
124 const nnp_pooling_function pooling_function = context->pooling_function;
125
126 const float (*input)[channels][input_size.height * input_size.width] =
127 (const float(*)[channels][input_size.height * input_size.width]) context->input_pointer;
128 float (*output)[channels][output_size.height * output_size.width] =
129 (float(*)[channels][output_size.height * output_size.width]) context->output_pointer;
130
131 pooling_function(
132 input[sample][channel],
133 output[sample][channel],
134 input_size.height, input_size.width,
135 input_padding.top, input_padding.left,
136 output_size.height, output_size.width,
137 pooling_stride.height, pooling_stride.width,
138 pooling_size.height, pooling_size.width);
139}
140
141enum nnp_status nnp_max_pooling_output(
142 size_t batch_size,
143 size_t channels,
144 struct nnp_size input_size,
145 struct nnp_padding input_padding,
146 struct nnp_size pooling_size,
147 struct nnp_size pooling_stride,
148 const float input[],
149 float output[],
150 pthreadpool_t threadpool)
151{
152 enum nnp_status status = validate_pooling_arguments(
153 batch_size, channels,
154 input_size, input_padding,
155 pooling_size, pooling_stride);
156 if (status != nnp_status_success) {
157 return status;
158 }
159
160 const struct nnp_size output_size = {
161 .height = divide_round_up(doz(input_padding.top + input_size.height + input_padding.bottom, pooling_size.height), pooling_stride.height) + 1,
162 .width = divide_round_up(doz(input_padding.left + input_size.width + input_padding.right, pooling_size.width), pooling_stride.width) + 1,
163 };
164
165 struct pooling_context pooling_context = {
166 .channels = channels,
167 .input_pointer = input,
168 .input_padding = input_padding,
169 .output_pointer = output,
170 .input_size = input_size,
171 .output_size = output_size,
172 .pooling_size = pooling_size,
173 .pooling_stride = pooling_stride,
174 .pooling_function = compute_max_pooling_forward__generic,
175 };
176
177 #if NNP_BACKEND_X86_64
178 if ((pooling_stride.height == 2) && (pooling_stride.width == 2) && (pooling_size.height == 2) && (pooling_size.width == 2)) {
179 pooling_context.pooling_function = compute_max_pooling_forward_2x2_2x2__avx2;
180 }
181 #endif
182
183 pthreadpool_parallelize_2d(threadpool,
184 (pthreadpool_task_2d_t) compute_pooling_output,
185 &pooling_context,
186 batch_size, channels,
187 PTHREADPOOL_FLAG_DISABLE_DENORMALS);
188
189 return nnp_status_success;
190}
191