1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include <cpuinfo.h> |
9 | #include <algorithm> |
10 | #include <cassert> |
11 | #include <iomanip> |
12 | #include <iostream> |
13 | #include <numeric> |
14 | |
15 | #include "./OptimizedKernelsAvx2.h" |
16 | #include "fbgemm/Fbgemm.h" |
17 | |
18 | namespace fbgemm { |
19 | |
/// Constructor: sets up the packed-A buffer for im2col-based convolution.
///
/// The base PackMatrix is initialized with:
///   rows = MB * prod(OUT_DIM)          (one row per output spatial position)
///   cols = prod(K) * IC                (one column per filter-tap * channel)
///
/// @param conv_p      convolution shape parameters
/// @param sdata       unpacked activation data (borrowed, not owned)
/// @param pmat        optional caller-provided packing buffer; if null, an
///                    aligned buffer is allocated here and freed by the base
/// @param a_zero_pt   activation zero point used to fill padding pixels
/// @param row_offset  optional caller-provided row-offset buffer
/// @param b_symmetric if true, B has zero offset so no row offsets are needed
/// @param params      optional blocking factors overriding the ISA defaults
/// @throws std::runtime_error if cpuinfo fails to initialize or groups do not
///         divide the number of columns
template <typename T, typename accT, int SPATIAL_DIM>
PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const T* sdata,
    inpType* pmat,
    int32_t a_zero_pt,
    int32_t* row_offset,
    bool b_symmetric,
    const BlockingFactors* params)
    : PackMatrix<PackAWithIm2Col<T, accT, SPATIAL_DIM>, T, accT>(
          conv_p.MB *
              std::accumulate(
                  conv_p.OUT_DIM.begin(),
                  conv_p.OUT_DIM.end(),
                  1,
                  std::multiplies<int>()),
          std::accumulate(
              conv_p.K.begin(),
              conv_p.K.end(),
              1,
              std::multiplies<int>()) *
              conv_p.IC,
          pmat,
          conv_p.G,
          params),
      conv_p_(conv_p),
      sdata_(sdata),
      a_zero_pt_(a_zero_pt) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!" );
  }
  // Packing requires at least AVX2; bail out early on unknown architectures.
  if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
       !fbgemmHasAvx2Support())) {
    assert(0 && "unknown architecure" );
  }

  if (params) {
    // Caller supplied explicit blocking factors; use them verbatim.
    BaseType::brow_ = params->MCB;
    BaseType::bcol_ = params->KCB;
    row_interleave_B_ = params->ROW_INTERLEAVE;
  } else {
    // Pick block-row/block-col sizes and B row interleave from the best
    // instruction set available on this machine.
    const inst_set_t isa = fbgemmInstructionSet();
    switch (isa) {
      case inst_set_t::avx512_vnni:
        std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
            PackingTraits<T, accT, inst_set_t::avx512_vnni>::
                getMatrixPackAParams();
        break;

      case inst_set_t::avx512_vnni_ymm:
        std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
            PackingTraits<T, accT, inst_set_t::avx512_vnni_ymm>::
                getMatrixPackAParams();
        break;

      case inst_set_t::avx512:
        std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
            PackingTraits<T, accT, inst_set_t::avx512>::getMatrixPackAParams();
        break;

      case inst_set_t::avx512_ymm:
        std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
            PackingTraits<T, accT, inst_set_t::avx512_ymm>::
                getMatrixPackAParams();
        break;

      case inst_set_t::avx2:
        std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
            PackingTraits<T, accT, inst_set_t::avx2>::getMatrixPackAParams();
        break;

      default:
        assert(0 && "unknown architecure" );
        throw std::runtime_error("unknown architecure" );
    }
  }

  // Grouped convolution requires columns to split evenly across groups.
  if (BaseType::numCols() % conv_p.G != 0) {
    throw std::runtime_error(
        "groups = " + std::to_string(conv_p.G) +
        " does not divide numCols = " + std::to_string(BaseType::numCols()));
  }
  if (pmat) {
    BaseType::buf_ = pmat;
  } else {
    // Allocate one (brow_ x bcol_) tile; ownership tracked by the base class.
    BaseType::bufAllocatedHere_ = true;
    BaseType::buf_ = static_cast<T*>(
        fbgemmAlignedAlloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
    // aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
  }
  if (!b_symmetric) {
    // Row offsets (per-row sums of A) are needed when B is not symmetric.
    if (row_offset) {
      rowOffsetAllocatedHere = false;
      row_offset_ = row_offset;
    } else {
      rowOffsetAllocatedHere = true;
      row_offset_ = static_cast<int32_t*>(
          fbgemmAlignedAlloc(64, BaseType::brow_ * sizeof(int32_t)));
    }
  }
}
121 | |
122 | template <int SPATIAL_DIM, int BCOL> |
123 | void pack_a_with_im2col_opt( |
124 | const conv_param_t<SPATIAL_DIM>& conv_p, |
125 | const block_type_t& block, |
126 | const uint8_t* sdata, |
127 | uint8_t* out, |
128 | int32_t a_zero_pt, |
129 | int32_t* row_offset_buf, |
130 | int COL_SIZE, |
131 | int COL_P_SIZE, |
132 | bool row_offset_acc) { |
133 | constexpr int IC = 3; |
134 | int IN_DIM_H = conv_p.IN_DIM[0]; |
135 | int IN_DIM_W = conv_p.IN_DIM[1]; |
136 | int K_H = conv_p.K[0]; |
137 | int K_W = conv_p.K[1]; |
138 | constexpr int STRIDE_H = 2; |
139 | constexpr int STRIDE_W = 2; |
140 | int PAD_H = conv_p.pad[0]; |
141 | int PAD_W = conv_p.pad[1]; |
142 | int OUT_DIM_H = conv_p.OUT_DIM[0]; |
143 | int OUT_DIM_W = conv_p.OUT_DIM[1]; |
144 | int OUT_DIM_HW = OUT_DIM_H * OUT_DIM_W; |
145 | |
146 | for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { |
147 | int n = i / OUT_DIM_HW; |
148 | int hw = i % OUT_DIM_HW; |
149 | int w = hw % OUT_DIM_W; |
150 | int h = hw / OUT_DIM_W; |
151 | |
152 | // j refers to column index within block |
153 | int j = 0; |
154 | // r and s iterate over K_H and K_W, respectively |
155 | for (int r = 0; r < K_H; ++r) { |
156 | int h_in = -PAD_H + h * STRIDE_H + r; |
157 | if (h_in < 0 || h_in >= IN_DIM_H) { |
158 | // Short-circuit if h_in is in padding. |
159 | std::memset( |
160 | out + (i - block.row_start) * BCOL + j, |
161 | a_zero_pt, |
162 | sizeof(uint8_t) * K_W * IC); |
163 | j += K_W * IC; |
164 | continue; |
165 | } |
166 | |
167 | int s = 0; |
168 | // left_pad_len : the number of spatial pixels we need to pad at the |
169 | // beginning |
170 | int left_pad_len = PAD_W - w * STRIDE_W; |
171 | if (left_pad_len > 0) { |
172 | std::memset( |
173 | out + (i - block.row_start) * BCOL + j, |
174 | a_zero_pt, |
175 | sizeof(uint8_t) * left_pad_len * IC); |
176 | s += left_pad_len; |
177 | } |
178 | |
179 | // mid_len : the number of spatial pixels that we handle normally |
180 | // (no padding) |
181 | int mid_len = std::min(IN_DIM_W + PAD_W - w * STRIDE_W, K_W) - s; |
182 | std::memcpy( |
183 | out + (i - block.row_start) * BCOL + j + s * IC, |
184 | sdata + |
185 | ((n * IN_DIM_H + h_in) * IN_DIM_W + -PAD_W + w * STRIDE_W + s) * |
186 | IC, |
187 | sizeof(uint8_t) * mid_len * IC); |
188 | s += mid_len; |
189 | |
190 | // right_pad_len : the number of spatial pixels we need to pad at the end |
191 | int right_pad_len = K_W - s; |
192 | if (right_pad_len > 0) { |
193 | std::memset( |
194 | out + (i - block.row_start) * BCOL + j + s * IC, |
195 | a_zero_pt, |
196 | sizeof(uint8_t) * right_pad_len * IC); |
197 | } |
198 | j += K_W * IC; |
199 | } // r loop |
200 | |
201 | // zero fill |
202 | // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill. |
203 | if (COL_P_SIZE - COL_SIZE > 0) { |
204 | std::memset( |
205 | &out[(i - block.row_start) * BCOL + COL_SIZE], |
206 | 0, |
207 | sizeof(uint8_t) * COL_P_SIZE - COL_SIZE); |
208 | } |
209 | |
210 | if (row_offset_buf) { |
211 | int32_t row_sum = |
212 | row_offset_acc ? row_offset_buf[i - block.row_start] : 0; |
213 | row_sum += reduceAvx2(out + (i - block.row_start) * BCOL, COL_SIZE); |
214 | row_offset_buf[i - block.row_start] = row_sum; |
215 | } |
216 | } |
217 | } |
218 | |
/// Packs one (row, col) block of the im2col-expanded activation matrix into
/// the tile buffer, filling convolution padding with the activation zero
/// point and maintaining per-row sums (row offsets) when requested.
///
/// Three code paths:
///  1. point-wise (1x1 kernel, no pad/stride/dilation): plain strided memcpy;
///  2. a specialized fast kernel for 2D, IC==3, stride 2x2 first layers;
///  3. the general case, handling 1D/2D/3D, grouped, dilated, and transposed
///     convolutions channel-block by channel-block.
template <typename T, typename accT, int SPATIAL_DIM>
void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
  // Round the column size up to a multiple of B's row interleave so the
  // packed tile lines up with how B is packed.
  block_type_t block_p = {
      block.row_start,
      block.row_size,
      block.col_start,
      (block.col_size + row_interleave_B_ - 1) / row_interleave_B_ *
          row_interleave_B_};
  BaseType::packedBlock(block_p);
  T* out = BaseType::getBuf();
  // accumulate into row offset?
  // (true when this is not the first column block of the group, i.e. partial
  // sums from earlier column blocks are already in row_offset_buf)
  bool row_offset_acc =
      (block.col_start % (this->numCols() / this->numGroups())) != 0;
  int32_t* row_offset_buf = getRowOffsetBuffer();

  // Detect a 1x1 convolution with no padding/stride/dilation: im2col is then
  // the identity and packing reduces to a strided copy.
  bool point_wise = true;
  for (int d = 0; d < SPATIAL_DIM; ++d) {
    if (conv_p_.K[d] != 1 || conv_p_.pad[d] != 0 || conv_p_.stride[d] != 1 ||
        conv_p_.dilation[d] != 1) {
      point_wise = false;
      break;
    }
  }
  for (int d = SPATIAL_DIM; d < SPATIAL_DIM * 2; ++d) {
    if (conv_p_.pad[d] != 0) {
      point_wise = false;
      break;
    }
  }

  // reduceAvx2 only written for T == uint8_t
  static_assert(
      std::is_same<T, uint8_t>::value,
      "PackAWithIm2Col<T, accT>::pack only works for T == uint8_t" );
  if (point_wise) {
    int32_t ld = this->numCols();
    if (row_offset_buf) {
      for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
        int buf_idx = i - block.row_start;
        memcpy(
            out + buf_idx * BaseType::blockColSize(),
            sdata_ + i * ld + block.col_start,
            block.col_size * sizeof(T));
        // zero fill
        for (int j = block.col_size; j < block_p.col_size; ++j) {
          out[buf_idx * BaseType::blockColSize() + j] = 0;
        }
        int32_t row_sum =
            row_offset_acc ? row_offset_buf[i - block.row_start] : 0;
        row_sum +=
            reduceAvx2(sdata_ + i * ld + block.col_start, block.col_size);
        row_offset_buf[i - block.row_start] = row_sum;
      }
    } else {
      for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
        int buf_idx = i - block.row_start;
        memcpy(
            out + buf_idx * BaseType::blockColSize(),
            sdata_ + i * ld + block.col_start,
            block.col_size * sizeof(T));
        // zero fill
        for (int j = block.col_size; j < block_p.col_size; ++j) {
          out[buf_idx * BaseType::blockColSize() + j] = 0;
        }
      }
    }

    return;
  }

  int ic_per_group = conv_p_.IC / conv_p_.G;

  // Fast path: 2D, 3-channel, stride-2, "same"-style padding, no dilation,
  // whole row fits in one column block. Common for the first conv layer.
  if (!conv_p_.transposed && SPATIAL_DIM == 2 && conv_p_.IC == 3 &&
      conv_p_.G == 1 && conv_p_.stride[0] == 2 && conv_p_.stride[1] == 2 &&
      block.col_start == 0 && conv_p_.pad[0] == ((conv_p_.K[0] - 1) / 2) &&
      conv_p_.pad[1] == ((conv_p_.K[1] - 1) / 2) &&
      block_p.col_size <= BaseType::blockColSize() &&
      conv_p_.dilation[0] == 1 && conv_p_.dilation[1] == 1 &&
      std::is_same<T, uint8_t>::value) {
    if (BaseType::blockColSize() == 256) {
      pack_a_with_im2col_opt<SPATIAL_DIM, 256>(
          conv_p_,
          block,
          reinterpret_cast<const uint8_t*>(sdata_),
          reinterpret_cast<uint8_t*>(out),
          a_zero_pt_,
          row_offset_buf,
          block.col_size,
          block_p.col_size,
          row_offset_acc);
      return;
    } else if (BaseType::blockColSize() == 512) {
      pack_a_with_im2col_opt<SPATIAL_DIM, 512>(
          conv_p_,
          block,
          reinterpret_cast<const uint8_t*>(sdata_),
          reinterpret_cast<uint8_t*>(out),
          a_zero_pt_,
          row_offset_buf,
          block.col_size,
          block_p.col_size,
          row_offset_acc);
      return;
    }
  }
  if (conv_p_.transposed) {
    // Transposed (deconvolution) case: an output pixel maps back to an input
    // pixel only when the offset is divisible by the stride; otherwise the
    // tap is filled with the zero point.
    for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
      if (SPATIAL_DIM == 1) { // static if
        int n = i / (conv_p_.OUT_DIM[0]);
        int ow = i % (conv_p_.OUT_DIM[0]);
        // Walk the column block one channel-group (ic_per_group) at a time.
        for (int j = block.col_start;
             j < block.col_start + block.col_size + ic_per_group - 1;
             j += ic_per_group) {
          int j_blk_id = j / ic_per_group;
          // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
          int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
          int j_blk_end = std::min(
              (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
          if (j_blk_start >= j_blk_end) {
            break;
          }

          // Decode (group g, filter tap s) from the channel-block index.
          int grs = j / ic_per_group;
          int s = grs % conv_p_.K[0];
          int g = grs / conv_p_.K[0];

          int w = ow + conv_p_.pad[0] - s * conv_p_.dilation[0];
          int w_in = w / conv_p_.stride[0];
          if (w_in * conv_p_.stride[0] == w && w_in >= 0 &&
              w_in < conv_p_.IN_DIM[0]) {
            std::memcpy(
                out + (i - block.row_start) * BaseType::blockColSize() +
                    j_blk_start - block.col_start,
                sdata_ + (n * conv_p_.IN_DIM[0] + w_in) * conv_p_.IC +
                    g * ic_per_group + (j_blk_start % ic_per_group),
                sizeof(T) * (j_blk_end - j_blk_start));
          } else {
            // Please note that padding for convolution should be filled with
            // zero_pt
            std::memset(
                out + (i - block.row_start) * BaseType::blockColSize() +
                    (j_blk_start - block.col_start),
                a_zero_pt_,
                sizeof(T) * (j_blk_end - j_blk_start));
          }
        }

      } else if (SPATIAL_DIM == 2) { // static if
        int n = i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
        int hw = i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
        int ow = hw % conv_p_.OUT_DIM[1];
        int oh = hw / conv_p_.OUT_DIM[1];
        for (int j = block.col_start;
             j < block.col_start + block.col_size + ic_per_group - 1;
             j += ic_per_group) {
          int j_blk_id = j / ic_per_group;
          // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
          int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
          int j_blk_end = std::min(
              (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
          if (j_blk_start >= j_blk_end) {
            break;
          }

          // Decode (group g, filter row r, filter col s).
          int grs = j / ic_per_group;
          int s = grs % conv_p_.K[1];
          int r = grs / conv_p_.K[1] % conv_p_.K[0];
          int g = grs / conv_p_.K[1] / conv_p_.K[0];

          int h = oh + conv_p_.pad[0] - r * conv_p_.dilation[0];
          int w = ow + conv_p_.pad[1] - s * conv_p_.dilation[1];

          int h_in = h / conv_p_.stride[0];
          int w_in = w / conv_p_.stride[1];

          if (h_in * conv_p_.stride[0] == h && h_in >= 0 &&
              h_in < conv_p_.IN_DIM[0] && w_in * conv_p_.stride[1] == w &&
              w_in >= 0 && w_in < conv_p_.IN_DIM[1]) {
            std::memcpy(
                out + (i - block.row_start) * BaseType::blockColSize() +
                    j_blk_start - block.col_start,
                sdata_ +
                    ((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] +
                     w_in) *
                        conv_p_.IC +
                    g * ic_per_group + (j_blk_start % ic_per_group),
                sizeof(T) * (j_blk_end - j_blk_start));
          } else {
            // Please note that padding for convolution should be filled with
            // zero_pt
            std::memset(
                out + (i - block.row_start) * BaseType::blockColSize() +
                    (j_blk_start - block.col_start),
                a_zero_pt_,
                sizeof(T) * (j_blk_end - j_blk_start));
          }
        }
      } else if (SPATIAL_DIM == 3) { // static if
        int n =
            i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
        int thw =
            i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
        int ow = thw % conv_p_.OUT_DIM[2];
        int oh = thw / conv_p_.OUT_DIM[2] % conv_p_.OUT_DIM[1];
        int ot = thw / conv_p_.OUT_DIM[2] / conv_p_.OUT_DIM[1];
        for (int j = block.col_start;
             j < block.col_start + block.col_size + ic_per_group - 1;
             j += ic_per_group) {
          int j_blk_id = j / ic_per_group;
          // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
          int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
          int j_blk_end = std::min(
              (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
          if (j_blk_start >= j_blk_end) {
            break;
          }

          // Decode (group g, filter depth q, filter row r, filter col s).
          int gqrs = j / ic_per_group;
          int s = gqrs % conv_p_.K[2];
          int r = gqrs / conv_p_.K[2] % conv_p_.K[1];
          int q = gqrs / conv_p_.K[2] / conv_p_.K[1] % conv_p_.K[0];
          int g = gqrs / conv_p_.K[2] / conv_p_.K[1] / conv_p_.K[0];

          int t = ot + conv_p_.pad[0] - q * conv_p_.dilation[0];
          int h = oh + conv_p_.pad[1] - r * conv_p_.dilation[1];
          int w = ow + conv_p_.pad[2] - s * conv_p_.dilation[2];
          int t_in = t / conv_p_.stride[0];
          int h_in = h / conv_p_.stride[1];
          int w_in = w / conv_p_.stride[2];

          if (t_in * conv_p_.stride[0] == t && t_in >= 0 &&
              t_in < conv_p_.IN_DIM[0] && h_in * conv_p_.stride[1] == h &&
              h_in >= 0 && h_in < conv_p_.IN_DIM[1] &&
              w_in * conv_p_.stride[2] == w && w_in >= 0 &&
              w_in < conv_p_.IN_DIM[2]) {
            std::memcpy(
                out + (i - block.row_start) * BaseType::blockColSize() +
                    j_blk_start - block.col_start,
                sdata_ +
                    (((n * conv_p_.IN_DIM[0] + t_in) * conv_p_.IN_DIM[1] +
                      h_in) *
                         conv_p_.IN_DIM[2] +
                     w_in) *
                        conv_p_.IC +
                    g * ic_per_group + (j_blk_start % ic_per_group),
                sizeof(T) * (j_blk_end - j_blk_start));
          } else {
            // Please note that padding for convolution should be filled with
            // zero_pt
            std::memset(
                &out
                    [(i - block.row_start) * BaseType::blockColSize() +
                     (j_blk_start - block.col_start)],
                a_zero_pt_,
                sizeof(T) * (j_blk_end - j_blk_start));
          }
        }
      }

      // zero fill
      // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
      if ((block_p.col_start + block_p.col_size) -
              (block.col_start + block.col_size) >
          0) {
        std::memset(
            &out
                [(i - block.row_start) * BaseType::blockColSize() +
                 (block.col_size)],
            0,
            sizeof(T) *
                ((block_p.col_start + block_p.col_size) -
                 (block.col_start + block.col_size)));
      }

      if (row_offset_buf) {
        int32_t row_sum =
            row_offset_acc ? row_offset_buf[i - block.row_start] : 0;
        row_sum += reduceAvx2(
            out + (i - block.row_start) * this->blockColSize(), block.col_size);
        row_offset_buf[i - block.row_start] = row_sum;
      }
    } // for each i
  } else {
    // General (forward) convolution: classic im2col gather with zero-point
    // fill for taps that land in padding.
    for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
      if (SPATIAL_DIM == 1) { // static if
        int n = i / (conv_p_.OUT_DIM[0]);
        int w = i % (conv_p_.OUT_DIM[0]);
        for (int j = block.col_start;
             j < block.col_start + block.col_size + ic_per_group - 1;
             j += ic_per_group) {
          int j_blk_id = j / ic_per_group;
          // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
          int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
          int j_blk_end = std::min(
              (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
          if (j_blk_start >= j_blk_end) {
            break;
          }

          int grs = j / ic_per_group;
          int s = grs % conv_p_.K[0];
          int g = grs / conv_p_.K[0];

          int w_in =
              -conv_p_.pad[0] + w * conv_p_.stride[0] + s * conv_p_.dilation[0];
          if (w_in < 0 || w_in >= conv_p_.IN_DIM[0]) {
            // Please note that padding for convolution should be filled with
            // zero_pt
            std::memset(
                out + (i - block.row_start) * BaseType::blockColSize() +
                    (j_blk_start - block.col_start),
                a_zero_pt_,
                sizeof(T) * (j_blk_end - j_blk_start));
          } else {
            std::memcpy(
                out + (i - block.row_start) * BaseType::blockColSize() +
                    j_blk_start - block.col_start,
                sdata_ + (n * conv_p_.IN_DIM[0] + w_in) * conv_p_.IC +
                    g * ic_per_group + (j_blk_start % ic_per_group),
                sizeof(T) * (j_blk_end - j_blk_start));
          }
        }

      } else if (SPATIAL_DIM == 2) { // static if
        int n = i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
        int hw = i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
        int w = hw % conv_p_.OUT_DIM[1];
        int h = hw / conv_p_.OUT_DIM[1];
        for (int j = block.col_start;
             j < block.col_start + block.col_size + ic_per_group - 1;
             j += ic_per_group) {
          int j_blk_id = j / ic_per_group;
          // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
          int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
          int j_blk_end = std::min(
              (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
          if (j_blk_start >= j_blk_end) {
            break;
          }

          int grs = j / ic_per_group;
          int s = grs % conv_p_.K[1];
          int r = grs / conv_p_.K[1] % conv_p_.K[0];
          int g = grs / conv_p_.K[1] / conv_p_.K[0];

          int h_in =
              -conv_p_.pad[0] + h * conv_p_.stride[0] + r * conv_p_.dilation[0];
          int w_in =
              -conv_p_.pad[1] + w * conv_p_.stride[1] + s * conv_p_.dilation[1];

          if (h_in < 0 || h_in >= conv_p_.IN_DIM[0] || w_in < 0 ||
              w_in >= conv_p_.IN_DIM[1]) {
            // Please note that padding for convolution should be filled with
            // zero_pt
            std::memset(
                out + (i - block.row_start) * BaseType::blockColSize() +
                    (j_blk_start - block.col_start),
                a_zero_pt_,
                sizeof(T) * (j_blk_end - j_blk_start));
          } else {
            int chn_start_idx = j_blk_start % ic_per_group;
            int src_offset =
                ((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] + w_in) *
                    conv_p_.IC +
                g * ic_per_group + chn_start_idx;
            // fast path
            // Copy across pixels of input width if we can. We can only do this
            // if the following conditions are met. 1) If the number of groups
            // is 1. For number of groups > 1, im2col
            // doesn't copy data across groups.
            // 2) If dilation is 1. For dilation > 1, copying from input
            // across channels is not sequential.
            // 3) For copy from the last channel (end of filter or
            // end of image width) for the current filter,
            // only copy if we have enough in the current channel.
            //
            if (conv_p_.G == 1 && conv_p_.dilation[1] == 1 &&
                ((s < (conv_p_.K[1] - 1) && w_in < (conv_p_.IN_DIM[1] - 1)) ||
                 ((chn_start_idx + block.col_size) <= ic_per_group))) {
              // left edge adjustment with s
              j_blk_end = std::min(
                  (j_blk_id + conv_p_.K[1] - s) * ic_per_group,
                  block.col_start + block.col_size);
              // right edge adjustment with w_in
              j_blk_end = std::min(
                  (j_blk_id + conv_p_.IN_DIM[1] - w_in) * ic_per_group,
                  j_blk_end);
              // advance j so the outer loop skips the columns covered here
              j += j_blk_end - j_blk_start - ic_per_group;
            }
            std::memcpy(
                out + (i - block.row_start) * BaseType::blockColSize() +
                    j_blk_start - block.col_start,
                sdata_ + src_offset,
                sizeof(T) * (j_blk_end - j_blk_start));
          }
        }
      } else if (SPATIAL_DIM == 3) { // static if
        int n =
            i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
        int thw =
            i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
        int w = thw % conv_p_.OUT_DIM[2];
        int h = thw / conv_p_.OUT_DIM[2] % conv_p_.OUT_DIM[1];
        int t = thw / conv_p_.OUT_DIM[2] / conv_p_.OUT_DIM[1];
        for (int j = block.col_start;
             j < block.col_start + block.col_size + ic_per_group - 1;
             j += ic_per_group) {
          int j_blk_id = j / ic_per_group;
          // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
          int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
          int j_blk_end = std::min(
              (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
          if (j_blk_start >= j_blk_end) {
            break;
          }

          int gqrs = j / ic_per_group;
          int s = gqrs % conv_p_.K[2];
          int r = gqrs / conv_p_.K[2] % conv_p_.K[1];
          int q = gqrs / conv_p_.K[2] / conv_p_.K[1] % conv_p_.K[0];
          int g = gqrs / conv_p_.K[2] / conv_p_.K[1] / conv_p_.K[0];

          int t_in =
              -conv_p_.pad[0] + t * conv_p_.stride[0] + q * conv_p_.dilation[0];
          int h_in =
              -conv_p_.pad[1] + h * conv_p_.stride[1] + r * conv_p_.dilation[1];
          int w_in =
              -conv_p_.pad[2] + w * conv_p_.stride[2] + s * conv_p_.dilation[2];

          if (t_in < 0 || t_in >= conv_p_.IN_DIM[0] || h_in < 0 ||
              h_in >= conv_p_.IN_DIM[1] || w_in < 0 ||
              w_in >= conv_p_.IN_DIM[2]) {
            // Please note that padding for convolution should be filled with
            // zero_pt
            std::memset(
                &out
                    [(i - block.row_start) * BaseType::blockColSize() +
                     (j_blk_start - block.col_start)],
                a_zero_pt_,
                sizeof(T) * (j_blk_end - j_blk_start));
          } else {
            std::memcpy(
                out + (i - block.row_start) * BaseType::blockColSize() +
                    j_blk_start - block.col_start,
                sdata_ +
                    (((n * conv_p_.IN_DIM[0] + t_in) * conv_p_.IN_DIM[1] +
                      h_in) *
                         conv_p_.IN_DIM[2] +
                     w_in) *
                        conv_p_.IC +
                    g * ic_per_group + (j_blk_start % ic_per_group),
                sizeof(T) * (j_blk_end - j_blk_start));
          }
        }
      }

      // zero fill
      // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
      if ((block_p.col_start + block_p.col_size) -
              (block.col_start + block.col_size) >
          0) {
        std::memset(
            &out
                [(i - block.row_start) * BaseType::blockColSize() +
                 (block.col_size)],
            0,
            sizeof(T) *
                ((block_p.col_start + block_p.col_size) -
                 (block.col_start + block.col_size)));
      }

      if (row_offset_buf) {
        int32_t row_sum =
            row_offset_acc ? row_offset_buf[i - block.row_start] : 0;
        row_sum += reduceAvx2(
            out + (i - block.row_start) * this->blockColSize(), block.col_size);
        row_offset_buf[i - block.row_start] = row_sum;
      }
    } // for each i
  }
}
700 | |
701 | template <typename T, typename accT, int SPATIAL_DIM> |
702 | void PackAWithIm2Col<T, accT, SPATIAL_DIM>::printPackedMatrix( |
703 | std::string name) { |
704 | std::cout << name << ":" |
705 | << "[" << BaseType::numPackedRows() << ", " |
706 | << BaseType::numPackedCols() << "]" << std::endl; |
707 | |
708 | T* out = BaseType::getBuf(); |
709 | for (auto r = 0; r < BaseType::numPackedRows(); ++r) { |
710 | for (auto c = 0; c < BaseType::numPackedCols(); ++c) { |
711 | T val = out[r * BaseType::blockColSize() + c]; |
712 | if (std::is_integral<T>::value) { |
713 | // cast to int64 because cout doesn't print int8_t type directly |
714 | std::cout << std::setw(5) << static_cast<int64_t>(val) << " " ; |
715 | } else { |
716 | std::cout << std::setw(5) << val << " " ; |
717 | } |
718 | } |
719 | std::cout << std::endl; |
720 | } |
721 | std::cout << std::endl; |
722 | } |
723 | |
724 | template <typename T, typename accT, int SPATIAL_DIM> |
725 | int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize( |
726 | const BlockingFactors* params) { |
727 | if (cpuinfo_initialize()) { |
728 | if (params) { |
729 | return params->MCB; |
730 | } else { |
731 | if (fbgemmHasAvx512VnniSupport()) { |
732 | return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; |
733 | } else if (fbgemmHasAvx512Support()) { |
734 | return PackingTraits<T, accT, inst_set_t::avx512>::MCB; |
735 | } else if (fbgemmHasAvx2Support()) { |
736 | return PackingTraits<T, accT, inst_set_t::avx2>::MCB; |
737 | } else { |
738 | // TODO: Have default slower path |
739 | assert(0 && "unsupported architecture" ); |
740 | return -1; |
741 | } |
742 | } |
743 | } else { |
744 | throw std::runtime_error("Failed to initialize cpuinfo!" ); |
745 | } |
746 | } |
747 | |
// Explicit instantiations for the supported activation type (uint8_t) with
// int32/int16 accumulation, for 1-D, 2-D, and 3-D convolutions.
template class PackAWithIm2Col<uint8_t, int32_t, 1>;
template class PackAWithIm2Col<uint8_t, int16_t, 1>;
template class PackAWithIm2Col<uint8_t, int32_t, 2>;
template class PackAWithIm2Col<uint8_t, int16_t, 2>;
template class PackAWithIm2Col<uint8_t, int32_t, 3>;
template class PackAWithIm2Col<uint8_t, int16_t, 3>;
754 | |
755 | } // namespace fbgemm |
756 | |