1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include <cpuinfo.h>
9#include <algorithm>
10#include <cassert>
11#include <iomanip>
12#include <iostream>
13#include <numeric>
14
15#include "./OptimizedKernelsAvx2.h"
16#include "fbgemm/Fbgemm.h"
17
18namespace fbgemm {
19
20template <typename T, typename accT, int SPATIAL_DIM>
21PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
22 const conv_param_t<SPATIAL_DIM>& conv_p,
23 const T* sdata,
24 inpType* pmat,
25 int32_t a_zero_pt,
26 int32_t* row_offset,
27 bool b_symmetric,
28 const BlockingFactors* params)
29 : PackMatrix<PackAWithIm2Col<T, accT, SPATIAL_DIM>, T, accT>(
30 conv_p.MB *
31 std::accumulate(
32 conv_p.OUT_DIM.begin(),
33 conv_p.OUT_DIM.end(),
34 1,
35 std::multiplies<int>()),
36 std::accumulate(
37 conv_p.K.begin(),
38 conv_p.K.end(),
39 1,
40 std::multiplies<int>()) *
41 conv_p.IC,
42 pmat,
43 conv_p.G,
44 params),
45 conv_p_(conv_p),
46 sdata_(sdata),
47 a_zero_pt_(a_zero_pt) {
48 if (!cpuinfo_initialize()) {
49 throw std::runtime_error("Failed to initialize cpuinfo!");
50 }
51 if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
52 !fbgemmHasAvx2Support())) {
53 assert(0 && "unknown architecure");
54 }
55
56 if (params) {
57 BaseType::brow_ = params->MCB;
58 BaseType::bcol_ = params->KCB;
59 row_interleave_B_ = params->ROW_INTERLEAVE;
60 } else {
61 const inst_set_t isa = fbgemmInstructionSet();
62 switch (isa) {
63 case inst_set_t::avx512_vnni:
64 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
65 PackingTraits<T, accT, inst_set_t::avx512_vnni>::
66 getMatrixPackAParams();
67 break;
68
69 case inst_set_t::avx512_vnni_ymm:
70 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
71 PackingTraits<T, accT, inst_set_t::avx512_vnni_ymm>::
72 getMatrixPackAParams();
73 break;
74
75 case inst_set_t::avx512:
76 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
77 PackingTraits<T, accT, inst_set_t::avx512>::getMatrixPackAParams();
78 break;
79
80 case inst_set_t::avx512_ymm:
81 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
82 PackingTraits<T, accT, inst_set_t::avx512_ymm>::
83 getMatrixPackAParams();
84 break;
85
86 case inst_set_t::avx2:
87 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
88 PackingTraits<T, accT, inst_set_t::avx2>::getMatrixPackAParams();
89 break;
90
91 default:
92 assert(0 && "unknown architecure");
93 throw std::runtime_error("unknown architecure");
94 }
95 }
96
97 if (BaseType::numCols() % conv_p.G != 0) {
98 throw std::runtime_error(
99 "groups = " + std::to_string(conv_p.G) +
100 " does not divide numCols = " + std::to_string(BaseType::numCols()));
101 }
102 if (pmat) {
103 BaseType::buf_ = pmat;
104 } else {
105 BaseType::bufAllocatedHere_ = true;
106 BaseType::buf_ = static_cast<T*>(
107 fbgemmAlignedAlloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
108 // aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
109 }
110 if (!b_symmetric) {
111 if (row_offset) {
112 rowOffsetAllocatedHere = false;
113 row_offset_ = row_offset;
114 } else {
115 rowOffsetAllocatedHere = true;
116 row_offset_ = static_cast<int32_t*>(
117 fbgemmAlignedAlloc(64, BaseType::brow_ * sizeof(int32_t)));
118 }
119 }
120}
121
122template <int SPATIAL_DIM, int BCOL>
123void pack_a_with_im2col_opt(
124 const conv_param_t<SPATIAL_DIM>& conv_p,
125 const block_type_t& block,
126 const uint8_t* sdata,
127 uint8_t* out,
128 int32_t a_zero_pt,
129 int32_t* row_offset_buf,
130 int COL_SIZE,
131 int COL_P_SIZE,
132 bool row_offset_acc) {
133 constexpr int IC = 3;
134 int IN_DIM_H = conv_p.IN_DIM[0];
135 int IN_DIM_W = conv_p.IN_DIM[1];
136 int K_H = conv_p.K[0];
137 int K_W = conv_p.K[1];
138 constexpr int STRIDE_H = 2;
139 constexpr int STRIDE_W = 2;
140 int PAD_H = conv_p.pad[0];
141 int PAD_W = conv_p.pad[1];
142 int OUT_DIM_H = conv_p.OUT_DIM[0];
143 int OUT_DIM_W = conv_p.OUT_DIM[1];
144 int OUT_DIM_HW = OUT_DIM_H * OUT_DIM_W;
145
146 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
147 int n = i / OUT_DIM_HW;
148 int hw = i % OUT_DIM_HW;
149 int w = hw % OUT_DIM_W;
150 int h = hw / OUT_DIM_W;
151
152 // j refers to column index within block
153 int j = 0;
154 // r and s iterate over K_H and K_W, respectively
155 for (int r = 0; r < K_H; ++r) {
156 int h_in = -PAD_H + h * STRIDE_H + r;
157 if (h_in < 0 || h_in >= IN_DIM_H) {
158 // Short-circuit if h_in is in padding.
159 std::memset(
160 out + (i - block.row_start) * BCOL + j,
161 a_zero_pt,
162 sizeof(uint8_t) * K_W * IC);
163 j += K_W * IC;
164 continue;
165 }
166
167 int s = 0;
168 // left_pad_len : the number of spatial pixels we need to pad at the
169 // beginning
170 int left_pad_len = PAD_W - w * STRIDE_W;
171 if (left_pad_len > 0) {
172 std::memset(
173 out + (i - block.row_start) * BCOL + j,
174 a_zero_pt,
175 sizeof(uint8_t) * left_pad_len * IC);
176 s += left_pad_len;
177 }
178
179 // mid_len : the number of spatial pixels that we handle normally
180 // (no padding)
181 int mid_len = std::min(IN_DIM_W + PAD_W - w * STRIDE_W, K_W) - s;
182 std::memcpy(
183 out + (i - block.row_start) * BCOL + j + s * IC,
184 sdata +
185 ((n * IN_DIM_H + h_in) * IN_DIM_W + -PAD_W + w * STRIDE_W + s) *
186 IC,
187 sizeof(uint8_t) * mid_len * IC);
188 s += mid_len;
189
190 // right_pad_len : the number of spatial pixels we need to pad at the end
191 int right_pad_len = K_W - s;
192 if (right_pad_len > 0) {
193 std::memset(
194 out + (i - block.row_start) * BCOL + j + s * IC,
195 a_zero_pt,
196 sizeof(uint8_t) * right_pad_len * IC);
197 }
198 j += K_W * IC;
199 } // r loop
200
201 // zero fill
202 // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
203 if (COL_P_SIZE - COL_SIZE > 0) {
204 std::memset(
205 &out[(i - block.row_start) * BCOL + COL_SIZE],
206 0,
207 sizeof(uint8_t) * COL_P_SIZE - COL_SIZE);
208 }
209
210 if (row_offset_buf) {
211 int32_t row_sum =
212 row_offset_acc ? row_offset_buf[i - block.row_start] : 0;
213 row_sum += reduceAvx2(out + (i - block.row_start) * BCOL, COL_SIZE);
214 row_offset_buf[i - block.row_start] = row_sum;
215 }
216 }
217}
218
219template <typename T, typename accT, int SPATIAL_DIM>
220void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
221 block_type_t block_p = {
222 block.row_start,
223 block.row_size,
224 block.col_start,
225 (block.col_size + row_interleave_B_ - 1) / row_interleave_B_ *
226 row_interleave_B_};
227 BaseType::packedBlock(block_p);
228 T* out = BaseType::getBuf();
229 // accumulate into row offset?
230 bool row_offset_acc =
231 (block.col_start % (this->numCols() / this->numGroups())) != 0;
232 int32_t* row_offset_buf = getRowOffsetBuffer();
233
234 bool point_wise = true;
235 for (int d = 0; d < SPATIAL_DIM; ++d) {
236 if (conv_p_.K[d] != 1 || conv_p_.pad[d] != 0 || conv_p_.stride[d] != 1 ||
237 conv_p_.dilation[d] != 1) {
238 point_wise = false;
239 break;
240 }
241 }
242 for (int d = SPATIAL_DIM; d < SPATIAL_DIM * 2; ++d) {
243 if (conv_p_.pad[d] != 0) {
244 point_wise = false;
245 break;
246 }
247 }
248
249 // reduceAvx2 only written for T == uint8_t
250 static_assert(
251 std::is_same<T, uint8_t>::value,
252 "PackAWithIm2Col<T, accT>::pack only works for T == uint8_t");
253 if (point_wise) {
254 int32_t ld = this->numCols();
255 if (row_offset_buf) {
256 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
257 int buf_idx = i - block.row_start;
258 memcpy(
259 out + buf_idx * BaseType::blockColSize(),
260 sdata_ + i * ld + block.col_start,
261 block.col_size * sizeof(T));
262 // zero fill
263 for (int j = block.col_size; j < block_p.col_size; ++j) {
264 out[buf_idx * BaseType::blockColSize() + j] = 0;
265 }
266 int32_t row_sum =
267 row_offset_acc ? row_offset_buf[i - block.row_start] : 0;
268 row_sum +=
269 reduceAvx2(sdata_ + i * ld + block.col_start, block.col_size);
270 row_offset_buf[i - block.row_start] = row_sum;
271 }
272 } else {
273 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
274 int buf_idx = i - block.row_start;
275 memcpy(
276 out + buf_idx * BaseType::blockColSize(),
277 sdata_ + i * ld + block.col_start,
278 block.col_size * sizeof(T));
279 // zero fill
280 for (int j = block.col_size; j < block_p.col_size; ++j) {
281 out[buf_idx * BaseType::blockColSize() + j] = 0;
282 }
283 }
284 }
285
286 return;
287 }
288
289 int ic_per_group = conv_p_.IC / conv_p_.G;
290
291 if (!conv_p_.transposed && SPATIAL_DIM == 2 && conv_p_.IC == 3 &&
292 conv_p_.G == 1 && conv_p_.stride[0] == 2 && conv_p_.stride[1] == 2 &&
293 block.col_start == 0 && conv_p_.pad[0] == ((conv_p_.K[0] - 1) / 2) &&
294 conv_p_.pad[1] == ((conv_p_.K[1] - 1) / 2) &&
295 block_p.col_size <= BaseType::blockColSize() &&
296 conv_p_.dilation[0] == 1 && conv_p_.dilation[1] == 1 &&
297 std::is_same<T, uint8_t>::value) {
298 if (BaseType::blockColSize() == 256) {
299 pack_a_with_im2col_opt<SPATIAL_DIM, 256>(
300 conv_p_,
301 block,
302 reinterpret_cast<const uint8_t*>(sdata_),
303 reinterpret_cast<uint8_t*>(out),
304 a_zero_pt_,
305 row_offset_buf,
306 block.col_size,
307 block_p.col_size,
308 row_offset_acc);
309 return;
310 } else if (BaseType::blockColSize() == 512) {
311 pack_a_with_im2col_opt<SPATIAL_DIM, 512>(
312 conv_p_,
313 block,
314 reinterpret_cast<const uint8_t*>(sdata_),
315 reinterpret_cast<uint8_t*>(out),
316 a_zero_pt_,
317 row_offset_buf,
318 block.col_size,
319 block_p.col_size,
320 row_offset_acc);
321 return;
322 }
323 }
324 if (conv_p_.transposed) {
325 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
326 if (SPATIAL_DIM == 1) { // static if
327 int n = i / (conv_p_.OUT_DIM[0]);
328 int ow = i % (conv_p_.OUT_DIM[0]);
329 for (int j = block.col_start;
330 j < block.col_start + block.col_size + ic_per_group - 1;
331 j += ic_per_group) {
332 int j_blk_id = j / ic_per_group;
333 // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
334 int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
335 int j_blk_end = std::min(
336 (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
337 if (j_blk_start >= j_blk_end) {
338 break;
339 }
340
341 int grs = j / ic_per_group;
342 int s = grs % conv_p_.K[0];
343 int g = grs / conv_p_.K[0];
344
345 int w = ow + conv_p_.pad[0] - s * conv_p_.dilation[0];
346 int w_in = w / conv_p_.stride[0];
347 if (w_in * conv_p_.stride[0] == w && w_in >= 0 &&
348 w_in < conv_p_.IN_DIM[0]) {
349 std::memcpy(
350 out + (i - block.row_start) * BaseType::blockColSize() +
351 j_blk_start - block.col_start,
352 sdata_ + (n * conv_p_.IN_DIM[0] + w_in) * conv_p_.IC +
353 g * ic_per_group + (j_blk_start % ic_per_group),
354 sizeof(T) * (j_blk_end - j_blk_start));
355 } else {
356 // Please note that padding for convolution should be filled with
357 // zero_pt
358 std::memset(
359 out + (i - block.row_start) * BaseType::blockColSize() +
360 (j_blk_start - block.col_start),
361 a_zero_pt_,
362 sizeof(T) * (j_blk_end - j_blk_start));
363 }
364 }
365
366 } else if (SPATIAL_DIM == 2) { // static if
367 int n = i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
368 int hw = i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
369 int ow = hw % conv_p_.OUT_DIM[1];
370 int oh = hw / conv_p_.OUT_DIM[1];
371 for (int j = block.col_start;
372 j < block.col_start + block.col_size + ic_per_group - 1;
373 j += ic_per_group) {
374 int j_blk_id = j / ic_per_group;
375 // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
376 int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
377 int j_blk_end = std::min(
378 (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
379 if (j_blk_start >= j_blk_end) {
380 break;
381 }
382
383 int grs = j / ic_per_group;
384 int s = grs % conv_p_.K[1];
385 int r = grs / conv_p_.K[1] % conv_p_.K[0];
386 int g = grs / conv_p_.K[1] / conv_p_.K[0];
387
388 int h = oh + conv_p_.pad[0] - r * conv_p_.dilation[0];
389 int w = ow + conv_p_.pad[1] - s * conv_p_.dilation[1];
390
391 int h_in = h / conv_p_.stride[0];
392 int w_in = w / conv_p_.stride[1];
393
394 if (h_in * conv_p_.stride[0] == h && h_in >= 0 &&
395 h_in < conv_p_.IN_DIM[0] && w_in * conv_p_.stride[1] == w &&
396 w_in >= 0 && w_in < conv_p_.IN_DIM[1]) {
397 std::memcpy(
398 out + (i - block.row_start) * BaseType::blockColSize() +
399 j_blk_start - block.col_start,
400 sdata_ +
401 ((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] +
402 w_in) *
403 conv_p_.IC +
404 g * ic_per_group + (j_blk_start % ic_per_group),
405 sizeof(T) * (j_blk_end - j_blk_start));
406 } else {
407 // Please note that padding for convolution should be filled with
408 // zero_pt
409 std::memset(
410 out + (i - block.row_start) * BaseType::blockColSize() +
411 (j_blk_start - block.col_start),
412 a_zero_pt_,
413 sizeof(T) * (j_blk_end - j_blk_start));
414 }
415 }
416 } else if (SPATIAL_DIM == 3) { // static if
417 int n =
418 i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
419 int thw =
420 i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
421 int ow = thw % conv_p_.OUT_DIM[2];
422 int oh = thw / conv_p_.OUT_DIM[2] % conv_p_.OUT_DIM[1];
423 int ot = thw / conv_p_.OUT_DIM[2] / conv_p_.OUT_DIM[1];
424 for (int j = block.col_start;
425 j < block.col_start + block.col_size + ic_per_group - 1;
426 j += ic_per_group) {
427 int j_blk_id = j / ic_per_group;
428 // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
429 int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
430 int j_blk_end = std::min(
431 (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
432 if (j_blk_start >= j_blk_end) {
433 break;
434 }
435
436 int gqrs = j / ic_per_group;
437 int s = gqrs % conv_p_.K[2];
438 int r = gqrs / conv_p_.K[2] % conv_p_.K[1];
439 int q = gqrs / conv_p_.K[2] / conv_p_.K[1] % conv_p_.K[0];
440 int g = gqrs / conv_p_.K[2] / conv_p_.K[1] / conv_p_.K[0];
441
442 int t = ot + conv_p_.pad[0] - q * conv_p_.dilation[0];
443 int h = oh + conv_p_.pad[1] - r * conv_p_.dilation[1];
444 int w = ow + conv_p_.pad[2] - s * conv_p_.dilation[2];
445 int t_in = t / conv_p_.stride[0];
446 int h_in = h / conv_p_.stride[1];
447 int w_in = w / conv_p_.stride[2];
448
449 if (t_in * conv_p_.stride[0] == t && t_in >= 0 &&
450 t_in < conv_p_.IN_DIM[0] && h_in * conv_p_.stride[1] == h &&
451 h_in >= 0 && h_in < conv_p_.IN_DIM[1] &&
452 w_in * conv_p_.stride[2] == w && w_in >= 0 &&
453 w_in < conv_p_.IN_DIM[2]) {
454 std::memcpy(
455 out + (i - block.row_start) * BaseType::blockColSize() +
456 j_blk_start - block.col_start,
457 sdata_ +
458 (((n * conv_p_.IN_DIM[0] + t_in) * conv_p_.IN_DIM[1] +
459 h_in) *
460 conv_p_.IN_DIM[2] +
461 w_in) *
462 conv_p_.IC +
463 g * ic_per_group + (j_blk_start % ic_per_group),
464 sizeof(T) * (j_blk_end - j_blk_start));
465 } else {
466 // Please note that padding for convolution should be filled with
467 // zero_pt
468 std::memset(
469 &out
470 [(i - block.row_start) * BaseType::blockColSize() +
471 (j_blk_start - block.col_start)],
472 a_zero_pt_,
473 sizeof(T) * (j_blk_end - j_blk_start));
474 }
475 }
476 }
477
478 // zero fill
479 // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
480 if ((block_p.col_start + block_p.col_size) -
481 (block.col_start + block.col_size) >
482 0) {
483 std::memset(
484 &out
485 [(i - block.row_start) * BaseType::blockColSize() +
486 (block.col_size)],
487 0,
488 sizeof(T) *
489 ((block_p.col_start + block_p.col_size) -
490 (block.col_start + block.col_size)));
491 }
492
493 if (row_offset_buf) {
494 int32_t row_sum =
495 row_offset_acc ? row_offset_buf[i - block.row_start] : 0;
496 row_sum += reduceAvx2(
497 out + (i - block.row_start) * this->blockColSize(), block.col_size);
498 row_offset_buf[i - block.row_start] = row_sum;
499 }
500 } // for each i
501 } else {
502 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
503 if (SPATIAL_DIM == 1) { // static if
504 int n = i / (conv_p_.OUT_DIM[0]);
505 int w = i % (conv_p_.OUT_DIM[0]);
506 for (int j = block.col_start;
507 j < block.col_start + block.col_size + ic_per_group - 1;
508 j += ic_per_group) {
509 int j_blk_id = j / ic_per_group;
510 // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
511 int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
512 int j_blk_end = std::min(
513 (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
514 if (j_blk_start >= j_blk_end) {
515 break;
516 }
517
518 int grs = j / ic_per_group;
519 int s = grs % conv_p_.K[0];
520 int g = grs / conv_p_.K[0];
521
522 int w_in =
523 -conv_p_.pad[0] + w * conv_p_.stride[0] + s * conv_p_.dilation[0];
524 if (w_in < 0 || w_in >= conv_p_.IN_DIM[0]) {
525 // Please note that padding for convolution should be filled with
526 // zero_pt
527 std::memset(
528 out + (i - block.row_start) * BaseType::blockColSize() +
529 (j_blk_start - block.col_start),
530 a_zero_pt_,
531 sizeof(T) * (j_blk_end - j_blk_start));
532 } else {
533 std::memcpy(
534 out + (i - block.row_start) * BaseType::blockColSize() +
535 j_blk_start - block.col_start,
536 sdata_ + (n * conv_p_.IN_DIM[0] + w_in) * conv_p_.IC +
537 g * ic_per_group + (j_blk_start % ic_per_group),
538 sizeof(T) * (j_blk_end - j_blk_start));
539 }
540 }
541
542 } else if (SPATIAL_DIM == 2) { // static if
543 int n = i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
544 int hw = i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
545 int w = hw % conv_p_.OUT_DIM[1];
546 int h = hw / conv_p_.OUT_DIM[1];
547 for (int j = block.col_start;
548 j < block.col_start + block.col_size + ic_per_group - 1;
549 j += ic_per_group) {
550 int j_blk_id = j / ic_per_group;
551 // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
552 int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
553 int j_blk_end = std::min(
554 (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
555 if (j_blk_start >= j_blk_end) {
556 break;
557 }
558
559 int grs = j / ic_per_group;
560 int s = grs % conv_p_.K[1];
561 int r = grs / conv_p_.K[1] % conv_p_.K[0];
562 int g = grs / conv_p_.K[1] / conv_p_.K[0];
563
564 int h_in =
565 -conv_p_.pad[0] + h * conv_p_.stride[0] + r * conv_p_.dilation[0];
566 int w_in =
567 -conv_p_.pad[1] + w * conv_p_.stride[1] + s * conv_p_.dilation[1];
568
569 if (h_in < 0 || h_in >= conv_p_.IN_DIM[0] || w_in < 0 ||
570 w_in >= conv_p_.IN_DIM[1]) {
571 // Please note that padding for convolution should be filled with
572 // zero_pt
573 std::memset(
574 out + (i - block.row_start) * BaseType::blockColSize() +
575 (j_blk_start - block.col_start),
576 a_zero_pt_,
577 sizeof(T) * (j_blk_end - j_blk_start));
578 } else {
579 int chn_start_idx = j_blk_start % ic_per_group;
580 int src_offset =
581 ((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] + w_in) *
582 conv_p_.IC +
583 g * ic_per_group + chn_start_idx;
584 // fast path
585 // Copy across pixels of input width if we can. We can only do this
586 // if the following conditions are met. 1) If the number of groups
587 // is 1. For number of groups > 1, im2col
588 // doesn't copy data across groups.
589 // 2) If dilation is 1. For dilation > 1, copying from input
590 // across channels is not sequential.
591 // 3) For copy from the last channel (end of filter or
592 // end of image width) for the current filter,
593 // only copy if we have enough in the current channel.
594 //
595 if (conv_p_.G == 1 && conv_p_.dilation[1] == 1 &&
596 ((s < (conv_p_.K[1] - 1) && w_in < (conv_p_.IN_DIM[1] - 1)) ||
597 ((chn_start_idx + block.col_size) <= ic_per_group))) {
598 // left edge adjustment with s
599 j_blk_end = std::min(
600 (j_blk_id + conv_p_.K[1] - s) * ic_per_group,
601 block.col_start + block.col_size);
602 // right edge adjustment with w_in
603 j_blk_end = std::min(
604 (j_blk_id + conv_p_.IN_DIM[1] - w_in) * ic_per_group,
605 j_blk_end);
606 j += j_blk_end - j_blk_start - ic_per_group;
607 }
608 std::memcpy(
609 out + (i - block.row_start) * BaseType::blockColSize() +
610 j_blk_start - block.col_start,
611 sdata_ + src_offset,
612 sizeof(T) * (j_blk_end - j_blk_start));
613 }
614 }
615 } else if (SPATIAL_DIM == 3) { // static if
616 int n =
617 i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
618 int thw =
619 i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
620 int w = thw % conv_p_.OUT_DIM[2];
621 int h = thw / conv_p_.OUT_DIM[2] % conv_p_.OUT_DIM[1];
622 int t = thw / conv_p_.OUT_DIM[2] / conv_p_.OUT_DIM[1];
623 for (int j = block.col_start;
624 j < block.col_start + block.col_size + ic_per_group - 1;
625 j += ic_per_group) {
626 int j_blk_id = j / ic_per_group;
627 // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
628 int j_blk_start = std::max(j_blk_id * ic_per_group, block.col_start);
629 int j_blk_end = std::min(
630 (j_blk_id + 1) * ic_per_group, block.col_start + block.col_size);
631 if (j_blk_start >= j_blk_end) {
632 break;
633 }
634
635 int gqrs = j / ic_per_group;
636 int s = gqrs % conv_p_.K[2];
637 int r = gqrs / conv_p_.K[2] % conv_p_.K[1];
638 int q = gqrs / conv_p_.K[2] / conv_p_.K[1] % conv_p_.K[0];
639 int g = gqrs / conv_p_.K[2] / conv_p_.K[1] / conv_p_.K[0];
640
641 int t_in =
642 -conv_p_.pad[0] + t * conv_p_.stride[0] + q * conv_p_.dilation[0];
643 int h_in =
644 -conv_p_.pad[1] + h * conv_p_.stride[1] + r * conv_p_.dilation[1];
645 int w_in =
646 -conv_p_.pad[2] + w * conv_p_.stride[2] + s * conv_p_.dilation[2];
647
648 if (t_in < 0 || t_in >= conv_p_.IN_DIM[0] || h_in < 0 ||
649 h_in >= conv_p_.IN_DIM[1] || w_in < 0 ||
650 w_in >= conv_p_.IN_DIM[2]) {
651 // Please note that padding for convolution should be filled with
652 // zero_pt
653 std::memset(
654 &out
655 [(i - block.row_start) * BaseType::blockColSize() +
656 (j_blk_start - block.col_start)],
657 a_zero_pt_,
658 sizeof(T) * (j_blk_end - j_blk_start));
659 } else {
660 std::memcpy(
661 out + (i - block.row_start) * BaseType::blockColSize() +
662 j_blk_start - block.col_start,
663 sdata_ +
664 (((n * conv_p_.IN_DIM[0] + t_in) * conv_p_.IN_DIM[1] +
665 h_in) *
666 conv_p_.IN_DIM[2] +
667 w_in) *
668 conv_p_.IC +
669 g * ic_per_group + (j_blk_start % ic_per_group),
670 sizeof(T) * (j_blk_end - j_blk_start));
671 }
672 }
673 }
674
675 // zero fill
676 // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
677 if ((block_p.col_start + block_p.col_size) -
678 (block.col_start + block.col_size) >
679 0) {
680 std::memset(
681 &out
682 [(i - block.row_start) * BaseType::blockColSize() +
683 (block.col_size)],
684 0,
685 sizeof(T) *
686 ((block_p.col_start + block_p.col_size) -
687 (block.col_start + block.col_size)));
688 }
689
690 if (row_offset_buf) {
691 int32_t row_sum =
692 row_offset_acc ? row_offset_buf[i - block.row_start] : 0;
693 row_sum += reduceAvx2(
694 out + (i - block.row_start) * this->blockColSize(), block.col_size);
695 row_offset_buf[i - block.row_start] = row_sum;
696 }
697 } // for each i
698 }
699}
700
701template <typename T, typename accT, int SPATIAL_DIM>
702void PackAWithIm2Col<T, accT, SPATIAL_DIM>::printPackedMatrix(
703 std::string name) {
704 std::cout << name << ":"
705 << "[" << BaseType::numPackedRows() << ", "
706 << BaseType::numPackedCols() << "]" << std::endl;
707
708 T* out = BaseType::getBuf();
709 for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
710 for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
711 T val = out[r * BaseType::blockColSize() + c];
712 if (std::is_integral<T>::value) {
713 // cast to int64 because cout doesn't print int8_t type directly
714 std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
715 } else {
716 std::cout << std::setw(5) << val << " ";
717 }
718 }
719 std::cout << std::endl;
720 }
721 std::cout << std::endl;
722}
723
724template <typename T, typename accT, int SPATIAL_DIM>
725int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize(
726 const BlockingFactors* params) {
727 if (cpuinfo_initialize()) {
728 if (params) {
729 return params->MCB;
730 } else {
731 if (fbgemmHasAvx512VnniSupport()) {
732 return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
733 } else if (fbgemmHasAvx512Support()) {
734 return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
735 } else if (fbgemmHasAvx2Support()) {
736 return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
737 } else {
738 // TODO: Have default slower path
739 assert(0 && "unsupported architecture");
740 return -1;
741 }
742 }
743 } else {
744 throw std::runtime_error("Failed to initialize cpuinfo!");
745 }
746}
747
// Explicit instantiations: uint8_t input with int32_t or int16_t as the
// second template argument (accT), for 1, 2 and 3 spatial dimensions.
template class PackAWithIm2Col<uint8_t, int32_t, 1>;
template class PackAWithIm2Col<uint8_t, int16_t, 1>;
template class PackAWithIm2Col<uint8_t, int32_t, 2>;
template class PackAWithIm2Col<uint8_t, int16_t, 2>;
template class PackAWithIm2Col<uint8_t, int32_t, 3>;
template class PackAWithIm2Col<uint8_t, int16_t, 3>;
754
755} // namespace fbgemm
756