1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "./RefImplementations.h"
9
10#include "fbgemm/FbgemmBuild.h"
11#include "fbgemm/FbgemmConvert.h"
12
13#include <algorithm>
14#include <cassert>
15#include <cmath>
16#include <cstring>
17#include <iostream>
18#include <numeric>
19#include <thread>
20
21using namespace std;
22
23namespace fbgemm {
24
25typedef union {
26 uint32_t I;
27 float F;
28} fint32;
29
30// Thread-safe random number generator
31//
32// Return a random 32bit integer using xoshiro128++
33// http://prng.di.unimi.it/xoshiro128plusplus.c
34inline uint32_t rnd128_next(int idx, int vlen) {
35 constexpr int VLEN_MAX = 16; // max vector size
36 alignas(64) static thread_local uint32_t g_rnd128_buffer[4 * VLEN_MAX];
37 static thread_local bool g_rnd128_initialized = false;
38
39 // Splitmix64: http://prng.di.unimi.it/splitmix64.c
40 auto rnd128_init_next = [](uint64_t& x) {
41 uint64_t z = (x += 0x9e3779b97f4a7c15);
42 z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
43 z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
44 return z ^ (z >> 31);
45 };
46
47 auto rotl = [](const uint32_t x, int k) {
48 return (x << k) | (x >> (32 - k));
49 };
50
51 if (!g_rnd128_initialized) {
    // Initialize the random buffer with unique values per thread
53 uint64_t h0 = std::hash<std::thread::id>{}(std::this_thread::get_id());
54 for (auto i = 0; i < 4; ++i) {
55 // Use thread hash as seed
56 g_rnd128_buffer[i * VLEN_MAX] = rnd128_init_next(h0);
57 uint64_t h1 = g_rnd128_buffer[i * VLEN_MAX];
58 for (auto v = 1; v < VLEN_MAX; ++v) {
59 g_rnd128_buffer[i * VLEN_MAX + v] = rnd128_init_next(h1);
60 }
61 }
62 g_rnd128_initialized = true;
63 }
64
65 const uint32_t result =
66 rotl(g_rnd128_buffer[idx] + g_rnd128_buffer[3 * vlen + idx], 7) +
67 g_rnd128_buffer[idx];
68
69 const uint32_t t = g_rnd128_buffer[1 * vlen + idx] << 9;
70
71 g_rnd128_buffer[2 * vlen + idx] ^= g_rnd128_buffer[0 * vlen + idx];
72 g_rnd128_buffer[3 * vlen + idx] ^= g_rnd128_buffer[1 * vlen + idx];
73 g_rnd128_buffer[1 * vlen + idx] ^= g_rnd128_buffer[2 * vlen + idx];
74 g_rnd128_buffer[0 * vlen + idx] ^= g_rnd128_buffer[3 * vlen + idx];
75
76 g_rnd128_buffer[2 * vlen + idx] ^= t;
77
78 g_rnd128_buffer[3 * vlen + idx] = rotl(g_rnd128_buffer[3 * vlen + idx], 11);
79
80 return result;
81}
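
// Usage sketch (illustrative only): a caller emulating a SIMD vector of width
// vlen (vlen <= VLEN_MAX) draws one 32-bit value per lane from the
// thread-local state, e.g.
//
//   uint32_t R[16];
//   int vlen = 8; // emulated vector width
//   for (int v = 0; v < vlen; ++v) {
//     R[v] = rnd128_next(v, vlen);
//   }
//
// This is how rowwise_sparse_adagrad_fused_ref below uses it for stochastic
// rounding.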
82
83void FloatToFloat16_ref(
84 const float* src,
85 float16* dst,
86 size_t size,
87 bool do_clip) {
88 constexpr float FP16_MAX = 65504.f;
89 if (do_clip) {
90 for (size_t i = 0; i < size; i++) {
91 float cur_src = std::max(-FP16_MAX, std::min(src[i], FP16_MAX));
92 dst[i] = cpu_float2half_rn(cur_src);
93 }
94 } else {
95 for (size_t i = 0; i < size; i++) {
96 dst[i] = cpu_float2half_rn(src[i]);
97 }
98 }
99}
100
101void Float16ToFloat_ref(const float16* src, float* dst, size_t size) {
102 for (size_t i = 0; i < size; i++) {
103 dst[i] = cpu_half2float(src[i]);
104 }
105}
106
107void FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size) {
108 for (size_t i = 0; i < size; i++) {
109 // Add 2^15 and right shift 16 to do round-nearest
110 dst[i] = (*reinterpret_cast<const uint32_t*>(src + i) + (1 << 15)) >> 16;
111 }
112}
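
// Worked example of the rounding above (illustrative only): for src bits
// 0x3F808000 (exactly halfway between two bfloat16 values), adding 0x8000
// gives 0x3F810000 and shifting right by 16 yields 0x3F81, i.e. the tie
// rounds up in magnitude. For src bits 0x3F807FFF the sum is 0x3F80FFFF and
// the result is 0x3F80, i.e. it rounds down. Note that ties round up rather
// than to even.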
113
114void Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size) {
115 for (size_t i = 0; i < size; i++) {
116 uint32_t val_fp32 =
117 static_cast<uint32_t>(reinterpret_cast<const uint16_t*>(src)[i]) << 16;
118 reinterpret_cast<uint32_t*>(dst)[i] = val_fp32;
119 }
120}
121
122void FloatToFloat8_ref(
123 const float input,
124 uint8_t* output,
125 int exponent_bits,
126 int exponent_bias) {
127 float max_pos = (1 << ((1 << exponent_bits) - 2 - exponent_bias)) *
128 (2 - std::pow(2, exponent_bits - 7));
129 int mantissa_bits = 7 - exponent_bits;
130 fint32 val_out, bouncer, smallest_normal;
131
132 val_out.F = input;
133 uint32_t sign_bit = val_out.I & 0x80000000;
134 val_out.I = val_out.I & 0x7FFFFFFF;
135 val_out.F = fminf(val_out.F, max_pos);
136
137 smallest_normal.I = (127 - exponent_bias + 1)
138 << 23; // smallest hfp8 normal number in FP32
  // It is unclear whether the input "min_pos" is the smallest denormalized
  // number or the smallest normalized number. The test below needs to be done
  // with the smallest normal number, whose numerical value is 2^(1-bias).

  // The conversion for denormalized values is slightly different. HFP8 is so
  // low precision that gradual underflow is probably crucial.
145 if (val_out.F >= smallest_normal.F) {
    // Use round-to-nearest-even. We rely on FP32's standard rounding mechanism
    // rather than rounding the mantissa ourselves and handling ties-to-even
    // and exponent increments. We want to round off (23 - mbits) bits of the
    // FP32 value val_in. This can be done by adding a power of 2 whose
    // exponent is exactly (23 - mbits) larger than the exponent of val_in,
    // which shifts val_in to the right so that rounding happens exactly at the
    // position that leaves mbits of explicit mantissa.
153 bouncer.I = (val_out.I & 0xFF800000) + ((23 - mantissa_bits) << 23);
154 val_out.F = (bouncer.F + val_out.F) - bouncer.F;
    // Adding the bouncer rounds off bits, and subtracting the bouncer
    // leaves the desired value, albeit in FP32 encoding.
    // All we need is to change the exponent encoding to use "bias".
158 val_out.I = uint32_t(val_out.I - ((127 - exponent_bias) << 23))
159 << (8 - exponent_bits);
160 val_out.I =
161 ((val_out.I | sign_bit) >>
162 24); // the 8 lsbs is the desired HFP8 encoding
163
164 } else {
    // When the value is in the denormal range, IEEE numbers essentially become
    // fixed-point numbers: the lsb is the smallest non-zero number,
    // 2^(1-bias-mbits). Hence, we define the bouncer so that its lsb is this
    // smallest non-zero number; adding the input to this bouncer forces
    // rounding to occur at the right position. Moreover, in this situation,
    // after adding the bouncer, the 8 least significant bits of the sum are
    // already the HFP8 encoding of the desired result. We just need to restore
    // the sign bit.
172 bouncer.I = (127 + (23 + (1 - exponent_bias - mantissa_bits))) << 23;
173 val_out.F = bouncer.F + val_out.F;
174 val_out.I = val_out.I | (sign_bit >> 24);
175 }
176
177 *output = val_out.I; // get the 8 lsbs
178}
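
// Worked example of the conversion above (illustrative only), with
// exponent_bits = 4 and exponent_bias = 7, i.e. mantissa_bits = 3:
//   input = 1.3f (bits 0x3FA66666), a normal number for this format.
//   bouncer.F = 2^20, so bouncer.F + 1.3f rounds to 1048577.25f, and
//   subtracting the bouncer leaves 1.25f (bits 0x3FA00000), i.e. the mantissa
//   has been rounded to 3 explicit bits.
//   Re-biasing and shifting: (0x3FA00000 - (120 << 23)) << 4 = 0x3A000000,
//   and keeping the top 8 bits gives the HFP8 encoding 0x3A
//   (sign 0, exponent field 0b0111 = bias, mantissa 0b010, value 1.25).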
179
180void Float8ToFloat_ref(
181 const uint8_t input,
182 float* output,
183 int exponent_bits,
184 int exponent_bias) {
185 fint32 val_out, sign, multiplier;
186
187 sign.I = (input & 0x80) << 24;
188 val_out.I = (input & 0x7F) << (24 - (8 - exponent_bits));
189 // so that the mantissa bits start at the mantissa bit positions of FP32
190 // encoding
191
  // Let the hfp8 mantissa bits correspond to the value frac, 0 <= frac < 1.
  // If the hfp8 value is a normal number, its value is 2^e x (1+frac),
  // where e is its (true, unbiased) exponent.
  // If the hfp8 value is denormal, the value is 2^(1-bias) x frac.

  // However, the bit pattern in the 8-bit exponent field of val_out.F
  // is bias+e when hfp8 is normal, and 0 when hfp8 is subnormal.
  // So, as an FP32 value, when hfp8 is normal, val_out.F represents the value
  // 2^(bias+e-127) * (1+frac).
  // When hfp8 is subnormal, val_out.F is also subnormal and represents the
  // value 2^(-126) * frac. In either case, val_out.F corresponds to
  // 2^(bias-127) * (value of hfp8 input). Thus, if we multiply val_out.F by
  // 2^(127-bias), we obtain the hfp8 value as an FP32 number.
205
206 multiplier.I = (127 + (127 - exponent_bias))
207 << 23; // multiplier.F is 2^(127-bias)
208 val_out.F *= multiplier.F;
209 val_out.I |= sign.I;
210 *output = val_out.F;
211}
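
// Worked example of the conversion above (illustrative only), decoding the
// HFP8 byte 0x3A with exponent_bits = 4 and exponent_bias = 7:
//   val_out.I = (0x3A & 0x7F) << 20 = 0x03A00000, which as an FP32 value is
//   2^(7-127) * 1.25 = 1.25 * 2^-120.
//   multiplier.F = 2^(127-7) = 2^120, so val_out.F * multiplier.F = 1.25,
//   recovering the value from the FloatToFloat8_ref example above.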
212
213void requantize_u8acc32_ref(
214 int M,
215 int N,
216 int ld,
217 const int32_t* inp,
218 uint8_t* out,
219 int32_t C_multiplier,
220 int32_t C_right_shift,
221 int32_t C_zero_point,
222 int32_t A_zero_point,
223 int32_t B_zero_point,
224 const int32_t* row_offsets,
225 const int32_t* col_offsets,
226 const int32_t* bias,
227 bool fuse_relu) {
228 int64_t nudge = 1ll << std::max(0, C_right_shift - 1);
229 for (int i = 0; i < M; ++i) {
230 for (int j = 0; j < N; ++j) {
231 int32_t raw = inp[i * ld + j];
232 if (A_zero_point) {
233 raw -= A_zero_point * col_offsets[j];
234 }
235 if (B_zero_point) {
236 raw -= B_zero_point * row_offsets[i];
237 }
238 if (bias) {
239 raw += bias[j];
240 }
241
242 int64_t ab_64 =
243 static_cast<int64_t>(raw) * static_cast<int64_t>(C_multiplier);
244 int64_t rounded = ((ab_64 + nudge) >> C_right_shift) + C_zero_point;
245
246 out[i * ld + j] = std::max(
247 fuse_relu ? static_cast<int64_t>(C_zero_point) : 0l,
248 std::min(static_cast<int64_t>(255l), rounded));
249 }
250 }
251}
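
// For reference, the function above computes, per output element,
//   raw = inp[i * ld + j] - A_zero_point * col_offsets[j]
//         - B_zero_point * row_offsets[i] + bias[j]
//   out = clamp(((raw * C_multiplier + nudge) >> C_right_shift) + C_zero_point,
//               fuse_relu ? C_zero_point : 0, 255)
// where the multiply is done in 64-bit and
// nudge = 1 << max(0, C_right_shift - 1) implements round-half-up on the
// right shift.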
252
253void requantize_u8acc32_ref(
254 int M,
255 int N,
256 int ld,
257 const int32_t* inp,
258 uint8_t* out,
259 const float* C_multiplier,
260 int32_t C_zero_point,
261 int32_t A_zero_point,
262 const int32_t* B_zero_point,
263 const int32_t* row_offsets,
264 const int32_t* col_offsets,
265 const int32_t* bias,
266 int ncols_per_quant_group,
267 bool fuse_relu) {
268 for (int i = 0; i < M; ++i) {
269 for (int j = 0; j < N; ++j) {
270 int32_t raw = inp[i * ld + j];
271 if (A_zero_point) {
272 raw -= A_zero_point * col_offsets[j];
273 }
274 raw -= B_zero_point[j / ncols_per_quant_group] * row_offsets[i];
275 if (bias) {
276 raw += bias[j];
277 }
278
279 float result = raw * C_multiplier[j / ncols_per_quant_group];
280 long rounded = lrintf(result) + C_zero_point;
281 out[i * ld + j] = std::max(
282 fuse_relu ? static_cast<long>(C_zero_point) : 0l,
283 std::min(255l, rounded));
284 }
285 }
286}
287
288void matmul_u8i8acc32_ref(
289 int M,
290 int N,
291 int K,
292 int lda,
293 int ldb,
294 int ldc,
295 const uint8_t* Aint8,
296 const int8_t* Bint8,
297 int32_t* Cint32) {
298 for (int i = 0; i < M; ++i) {
299 for (int j = 0; j < N; ++j) {
300 int32_t sum = 0;
301 for (int k = 0; k < K; ++k) {
302 sum += static_cast<int32_t>(Aint8[i * lda + k]) *
303 static_cast<int32_t>(Bint8[k * ldb + j]);
304 }
305 Cint32[i * ldc + j] = sum;
306 }
307 }
308}
309
310void matmul_u8i8acc16_ref(
311 int M,
312 int N,
313 int K,
314 int lda,
315 int ldb,
316 int ldc,
317 int brow,
318 const uint8_t* Aint8,
319 const int8_t* Bint8,
320 int32_t* Cint32) {
321 for (int i = 0; i < M; ++i) {
322 for (int j = 0; j < N; ++j) {
323 int32_t sum = 0, sum_32bit = 0;
324 for (int k = 0; k < K; k += 2) {
325 int a0 = Aint8[i * lda + k];
326 int b0 = Bint8[k * ldb + j];
327 int a1 = 0, b1 = 0;
328 if (k + 1 < K) {
329 a1 = Aint8[i * lda + k + 1];
330 b1 = Bint8[(k + 1) * ldb + j];
331 }
332 sum = clip_16bit(sum + clip_16bit(a0 * b0 + a1 * b1));
333 if ((k % brow) == (brow - 2)) {
334 sum_32bit += sum;
335 sum = 0;
336 }
337 }
338 Cint32[i * ldc + j] = sum_32bit + sum;
339 }
340 }
341}
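
// The loop above models a kernel that accumulates pairs of products in a
// saturating 16-bit accumulator (clip_16bit) and spills into a 32-bit
// accumulator every `brow` values of k. For illustration, with brow = 4 the
// products for k = 0..3 are accumulated (and clipped) in 16 bits, then added
// to sum_32bit and the 16-bit accumulator is reset before k = 4.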
342
343void cblas_sgemm_ref(
344 const matrix_op_t transa,
345 const matrix_op_t transb,
346 const int m,
347 const int n,
348 const int k,
349 float alpha,
350 const float* Afp32,
351 int lda,
352 const float* Bfp32,
353 int ldb,
354 float beta,
355 float* Cfp32,
356 int ldc) {
357 for (int i = 0; i < m; ++i) {
358 for (int j = 0; j < n; ++j) {
359 float sum = 0;
360 for (int p = 0; p < k; ++p) {
361 float a =
362 (transa == matrix_op_t::NoTranspose ? Afp32[i * lda + p]
363 : Afp32[p * lda + i]);
364 float b =
365 (transb == matrix_op_t::NoTranspose ? Bfp32[p * ldb + j]
366 : Bfp32[j * ldb + p]);
367 sum += a * b;
368 }
369 if (beta == 0) {
370 Cfp32[i * ldc + j] = alpha * sum;
371 } else {
372 Cfp32[i * ldc + j] = alpha * sum + beta * Cfp32[i * ldc + j];
373 }
374 }
375 }
376}
377
378namespace {
379// From https://stackoverflow.com/questions/31652875
380uint64_t umul64wide(uint64_t a, uint64_t b) {
381 uint64_t a_lo = static_cast<uint32_t>(a);
382 uint64_t a_hi = a >> 32;
383 uint64_t b_lo = static_cast<uint32_t>(b);
384 uint64_t b_hi = b >> 32;
385
386 uint64_t p0 = a_lo * b_lo;
387 uint64_t p1 = a_lo * b_hi;
388 uint64_t p2 = a_hi * b_lo;
389
390 return p0 + (p1 << 32) + (p2 << 32);
391}
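
// Note: this returns only the low 64 bits of the 128-bit product (the
// a_hi * b_hi term is dropped), i.e. multiplication modulo 2^64. Doing the
// multiply on unsigned values keeps the intentional wrap-around in
// cblas_gemm_i64_i64acc_ref below well defined.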
392} // namespace
393
394// Expected to have overflows
395NO_SANITIZE("undefined")
396void cblas_gemm_i64_i64acc_ref(
397 matrix_op_t transa,
398 matrix_op_t transb,
399 int M,
400 int N,
401 int K,
402 const int64_t* A,
403 int lda,
404 const int64_t* B,
405 int ldb,
406 bool accumulate,
407 int64_t* C,
408 int ldc) {
409 for (int i = 0; i < M; ++i) {
410 for (int j = 0; j < N; ++j) {
411 int64_t acc;
412 if (accumulate) {
413 acc = C[i * ldc + j];
414 } else {
415 acc = 0;
416 }
417 for (int k = 0; k < K; ++k) {
418 int64_t a =
419 A[transa == matrix_op_t::Transpose ? i + k * lda : i * lda + k];
420 int64_t b =
421 B[transb == matrix_op_t::Transpose ? k + j * ldb : k * ldb + j];
422 int64_t lo = umul64wide(a, b);
423 acc += lo;
424 }
425 C[i * ldc + j] = acc;
426 } // j
427 } // i
428}
429
430void row_offsets_u8acc32_ref(
431 int M,
432 int K,
433 int ld,
434 const uint8_t* Aint8,
435 int32_t* row_offsets) {
436 // row offset
437 for (int i = 0; i < M; ++i) {
438 int32_t sum = 0;
439 for (int k = 0; k < K; ++k) {
440 sum += static_cast<int32_t>(Aint8[i * ld + k]);
441 }
442 row_offsets[i] = sum;
443 }
444}
445
446void col_offsets_with_zero_pt_s8acc32_ref(
447 int K,
448 int N,
449 int ld,
450 const int8_t* Bint8,
451 const int32_t* B_zero_point,
452 int32_t* col_offsets,
453 int ncols_per_quant_group) {
454 for (int j = 0; j < N; ++j) {
455 int32_t sum = 0;
456 for (int k = 0; k < K; ++k) {
457 sum += Bint8[k * ld + j];
458 }
459 col_offsets[j] = sum - B_zero_point[j / ncols_per_quant_group] * K;
460 }
461}
462
463void spmdm_ref(
464 int M,
465 const uint8_t* A,
466 int lda,
467 fbgemm::CompressedSparseColumn& B,
468 bool accumulation,
469 int32_t* C,
470 int ldc,
471 int groups /*=1*/) {
472 int N = B.NumOfCols();
473 assert(N % groups == 0);
474 if (!accumulation) {
475 for (int i = 0; i < M; ++i) {
476 for (int j = 0; j < N; ++j) {
477 C[i * ldc + j] = 0;
478 }
479 }
480 }
481 for (int g = 0; g < groups; ++g) {
482 for (int j = g * (N / groups); j < (g + 1) * (N / groups); ++j) {
483 for (int k = B.ColPtr()[j]; k < B.ColPtr()[j + 1]; ++k) {
484 int row = g * B.NumOfRows() + B.RowIdx()[k];
485 int w = B.Values()[k];
486 for (int i = 0; i < M; ++i) {
487 C[i * ldc + j] += A[i * lda + row] * w;
488 }
489 }
490 } // for each column of B
491 } // for each group
492}
493
494int32_t clip_16bit(int32_t x) {
495 if (x > numeric_limits<int16_t>::max()) {
496 return std::min<int>(numeric_limits<int16_t>::max(), x);
497 } else if (x < numeric_limits<int16_t>::min()) {
498 return std::max<int>(numeric_limits<int16_t>::min(), x);
499 } else {
500 return x;
501 }
502}
503
504/* Imitate the Im2Col<float, CPUContext, StorageOrder::NWC> function
505 * from caffe2/utils/math_cpu.cc
506 * NWC StorageOrder/Layout
507 * A: NWC: NW_0 x C_0
508 * Ao: NWC: NW_1 x G KW C_0/G
509 */
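/* Illustrative shape example (parameters chosen arbitrarily): with MB = 1,
 * IN_DIM = {8}, IC = 4, G = 2, K = {3}, stride 1, pad 1 and dilation 1
 * (so OUT_DIM = {8}), A has 8 x 4 entries and Ao has 8 x (2 * 3 * 2) = 8 x 12
 * entries, indexed as (n, ow, g, s, ic_within_group).
 */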
510template <>
511FBGEMM_API void im2col_ref(
512 const conv_param_t<1>& conv_p,
513 const uint8_t* A,
514 int32_t A_zero_point,
515 uint8_t* Ao) {
516 int IC = conv_p.IC;
517 int G = conv_p.G;
518 assert(IC % G == 0);
519 array<int, 1> IN_DIM = conv_p.IN_DIM;
520 array<int, 1> OUT_DIM = conv_p.OUT_DIM;
521 array<int, 1> K = conv_p.K;
522
523 if (conv_p.transposed) {
524 for (int n = 0; n < conv_p.MB; ++n) {
525 for (int ow = 0; ow < OUT_DIM[0]; ++ow) {
526 for (int s = 0; s < K[0]; ++s) {
527 int w = ow + conv_p.pad[0] - s * conv_p.dilation[0];
528 int w_in = w / conv_p.stride[0];
529 if (w_in * conv_p.stride[0] == w && w_in >= 0 && w_in < IN_DIM[0]) {
530 for (int g = 0; g < G; ++g) {
531 memcpy(
532 Ao + (((n * OUT_DIM[0] + ow) * G + g) * K[0] + s) * (IC / G),
533 A + (n * IN_DIM[0] + w_in) * IC + g * (IC / G),
534 sizeof(uint8_t) * (IC / G));
535 }
536 } else {
537 for (int g = 0; g < G; ++g) {
538 memset(
539 Ao + (((n * OUT_DIM[0] + ow) * G + g) * K[0] + s) * (IC / G),
540 A_zero_point,
541 sizeof(uint8_t) * (IC / G));
542 }
543 }
544 } // for each s
545 } // for each ow
546 } // for each n
547 } else {
548 for (int n = 0; n < conv_p.MB; ++n) {
549 for (int w = 0; w < OUT_DIM[0]; ++w) {
550 for (int s = 0; s < K[0]; ++s) {
551 int w_in =
552 -conv_p.pad[0] + w * conv_p.stride[0] + s * conv_p.dilation[0];
553 if (w_in < 0 || w_in >= IN_DIM[0]) {
554 for (int g = 0; g < G; ++g) {
555 memset(
556 Ao + (((n * OUT_DIM[0] + w) * G + g) * K[0] + s) * (IC / G),
557 A_zero_point,
558 sizeof(uint8_t) * (IC / G));
559 }
560 } else {
561 for (int g = 0; g < G; ++g) {
562 memcpy(
563 Ao + (((n * OUT_DIM[0] + w) * G + g) * K[0] + s) * (IC / G),
564 A + (n * IN_DIM[0] + w_in) * IC + g * (IC / G),
565 sizeof(uint8_t) * (IC / G));
566 }
567 }
568 } // for each s
569 } // for each w
570 } // for each n
571 }
572}
573
574/* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function
575 * from caffe2/utils/math_cpu.cc
576 * NHWC StorageOrder/Layout
577 * A: NHWC: NH_0W_0 x C_0
578 * Ao: NHWC: NH_1W_1 x G RS C_0/G
579 */
580template <>
581FBGEMM_API void im2col_ref(
582 const conv_param_t<2>& conv_p,
583 const uint8_t* A,
584 int32_t A_zero_point,
585 uint8_t* Ao) {
586 int IC = conv_p.IC;
587 int G = conv_p.G;
588 assert(IC % G == 0);
589 array<int, 2> IN_DIM = conv_p.IN_DIM;
590 array<int, 2> OUT_DIM = conv_p.OUT_DIM;
591 array<int, 2> K = conv_p.K;
592
593 if (conv_p.transposed) {
594 for (int n = 0; n < conv_p.MB; ++n) {
595 for (int oh = 0; oh < OUT_DIM[0]; ++oh) {
596 for (int ow = 0; ow < OUT_DIM[1]; ++ow) {
597 for (int r = 0; r < K[0]; ++r) {
598 for (int s = 0; s < K[1]; ++s) {
599 int h = oh + conv_p.pad[0] - r * conv_p.dilation[0];
600 int w = ow + conv_p.pad[1] - s * conv_p.dilation[1];
601 int h_in = h / conv_p.stride[0];
602 int w_in = w / conv_p.stride[1];
603 if (h_in * conv_p.stride[0] == h && h_in >= 0 &&
604 h_in < IN_DIM[0] && w_in * conv_p.stride[1] == w &&
605 w_in >= 0 && w_in < IN_DIM[1]) {
606 for (int g = 0; g < G; ++g) {
607 memcpy(
608 Ao +
609 (((((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * G +
610 g) *
611 K[0] +
612 r) *
613 K[1] +
614 s) *
615 (IC / G),
616 A + ((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
617 g * (IC / G),
618 sizeof(uint8_t) * (IC / G));
619 }
620 } else {
621 for (int g = 0; g < G; ++g) {
622 memset(
623 Ao +
624 (((((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * G +
625 g) *
626 K[0] +
627 r) *
628 K[1] +
629 s) *
630 (IC / G),
631 A_zero_point,
632 sizeof(uint8_t) * (IC / G));
633 }
634 }
635 } // for each s
636 } // for each r
637 } // for each ow
638 } // for each oh
639 } // for each n
640 } else {
641 for (int n = 0; n < conv_p.MB; ++n) {
642 for (int h = 0; h < OUT_DIM[0]; ++h) {
643 for (int w = 0; w < OUT_DIM[1]; ++w) {
644 for (int r = 0; r < K[0]; ++r) {
645 int h_in =
646 -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
647 for (int s = 0; s < K[1]; ++s) {
648 int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
649 s * conv_p.dilation[1];
650 if (h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
651 w_in >= IN_DIM[1]) {
652 for (int g = 0; g < G; ++g) {
653 memset(
654 Ao +
655 (((((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * G + g) *
656 K[0] +
657 r) *
658 K[1] +
659 s) *
660 (IC / G),
661 A_zero_point,
662 sizeof(uint8_t) * (IC / G));
663 }
664 } else {
665 for (int g = 0; g < G; ++g) {
666 memcpy(
667 Ao +
668 (((((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * G + g) *
669 K[0] +
670 r) *
671 K[1] +
672 s) *
673 (IC / G),
674 A + ((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
675 g * (IC / G),
676 sizeof(uint8_t) * (IC / G));
677 }
678 }
679 } // for each s
680 } // for each r
681 } // for each w
682 } // for each h
683 } // for each n
684 }
685}
686
687/* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function
688 * from caffe2/utils/math_cpu.cc
689 * NHWC StorageOrder/Layout
690 * A: NHWC: NT_0H_0W_0 x C_0
691 * Ao: NHWC: NT_1H_1W_1 x G QRS C_0/G
692 */
693template <>
694FBGEMM_API void im2col_ref(
695 const conv_param_t<3>& conv_p,
696 const uint8_t* A,
697 int32_t A_zero_point,
698 uint8_t* Ao) {
699 int IC = conv_p.IC;
700 int G = conv_p.G;
701 assert(IC % G == 0);
702 array<int, 3> IN_DIM = conv_p.IN_DIM;
703 array<int, 3> OUT_DIM = conv_p.OUT_DIM;
704 array<int, 3> K = conv_p.K;
705
706 if (conv_p.transposed) {
707 for (int n = 0; n < conv_p.MB; ++n) {
708 for (int ot = 0; ot < OUT_DIM[0]; ++ot) {
709 for (int oh = 0; oh < OUT_DIM[1]; ++oh) {
710 for (int ow = 0; ow < OUT_DIM[2]; ++ow) {
711 for (int q = 0; q < K[0]; ++q) {
712 for (int r = 0; r < K[1]; ++r) {
713 for (int s = 0; s < K[2]; ++s) {
714 int t = ot + conv_p.pad[0] - q * conv_p.dilation[0];
715 int h = oh + conv_p.pad[1] - r * conv_p.dilation[1];
716 int w = ow + conv_p.pad[2] - s * conv_p.dilation[2];
717 int t_in = t / conv_p.stride[0];
718 int h_in = h / conv_p.stride[1];
719 int w_in = w / conv_p.stride[2];
720 if (t_in * conv_p.stride[0] == t && t_in >= 0 &&
721 t_in < IN_DIM[0] && h_in * conv_p.stride[1] == h &&
722 h_in >= 0 && h_in < IN_DIM[1] &&
723 w_in * conv_p.stride[2] == w && w_in >= 0 &&
724 w_in < IN_DIM[2]) {
725 for (int g = 0; g < G; ++g) {
726 memcpy(
727 Ao +
728 (((((((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) *
729 OUT_DIM[2] +
730 ow) *
731 G +
732 g) *
733 K[0] +
734 q) *
735 K[1] +
736 r) *
737 K[2] +
738 s) *
739 (IC / G),
740 A +
741 (((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
742 IN_DIM[2] +
743 w_in) *
744 IC +
745 g * (IC / G),
746 sizeof(uint8_t) * (IC / G));
747 }
748 } else {
749 for (int g = 0; g < G; ++g) {
750 memset(
751 Ao +
752 (((((((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) *
753 OUT_DIM[2] +
754 ow) *
755 G +
756 g) *
757 K[0] +
758 q) *
759 K[1] +
760 r) *
761 K[2] +
762 s) *
763 (IC / G),
764 A_zero_point,
765 sizeof(uint8_t) * (IC / G));
766 }
767 }
768 } // for each s
769 } // for each r
770 } // for each q
771 } // for each ow
772 } // for each oh
773 } // for each ot
774 } // for each n
775 } else {
776 for (int n = 0; n < conv_p.MB; ++n) {
777 for (int t = 0; t < OUT_DIM[0]; ++t) {
778 for (int h = 0; h < OUT_DIM[1]; ++h) {
779 for (int w = 0; w < OUT_DIM[2]; ++w) {
780 for (int q = 0; q < K[0]; ++q) {
781 int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
782 q * conv_p.dilation[0];
783 for (int r = 0; r < K[1]; ++r) {
784 int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
785 r * conv_p.dilation[1];
786 for (int s = 0; s < K[2]; ++s) {
787 int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
788 s * conv_p.dilation[2];
789 if (t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
790 h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]) {
791 for (int g = 0; g < G; ++g) {
792 memset(
793 Ao +
794 (((((((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) *
795 OUT_DIM[2] +
796 w) *
797 G +
798 g) *
799 K[0] +
800 q) *
801 K[1] +
802 r) *
803 K[2] +
804 s) *
805 (IC / G),
806 A_zero_point,
807 sizeof(uint8_t) * (IC / G));
808 }
809 } else {
810 for (int g = 0; g < G; ++g) {
811 memcpy(
812 Ao +
813 (((((((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) *
814 OUT_DIM[2] +
815 w) *
816 G +
817 g) *
818 K[0] +
819 q) *
820 K[1] +
821 r) *
822 K[2] +
823 s) *
824 (IC / G),
825 A +
826 (((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
827 IN_DIM[2] +
828 w_in) *
829 IC +
830 g * (IC / G),
831 sizeof(uint8_t) * (IC / G));
832 }
833 }
834 } // for each s
835 } // for each r
836 } // for each q
837 } // for each w
838 } // for each h
839 } // for each t
840 } // for each n
841 }
842}
843
844// 1D Conv
845template <>
846FBGEMM_API void conv_ref(
847 const conv_param_t<1>& conv_p,
848 const uint8_t* A,
849 int32_t A_zero_point,
850 const int8_t* B,
851 int32_t* C) {
852 // A is assumed to be (N Lin Cin)
853 // B is assumed to be (G K Cin/G Cout/G)
854 // C is assumed to be (N Lout Cout)
855 int IC = conv_p.IC;
856 int OC = conv_p.OC;
857 int G = conv_p.G;
858 assert(IC % G == 0);
859 assert(OC % G == 0);
860 array<int, 1> IN_DIM = conv_p.IN_DIM;
861 array<int, 1> OUT_DIM = conv_p.OUT_DIM;
862 array<int, 1> K = conv_p.K;
863
864 if (conv_p.transposed) {
    // For the reference implementation, there is no padding on the input
    // buffer; the padding specifies how much we remove from the output
    // buffers.
867 for (int n = 0; n < conv_p.MB; ++n) {
868 for (int ow = 0; ow < OUT_DIM[0]; ++ow) {
869 // stride on output is fractional stride on input
870 // conv index is
871 // int w_in = -conv_p.pad[0] + w* conv_p.stride[0] + r*
872 // conv_p.dilation[0];
873 // so we reverse it
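        //
        // Illustrative example (stride = 2, pad = 1, dilation = 1,
        // IN_DIM[0] > 3): for ow = 5, r = 0 gives w = 6 and w_in = 3; since
        // 3 * 2 == 6, input position 3 contributes. r = 1 gives w = 5 and
        // w_in = 2, but 2 * 2 != 5, so that tap is skipped.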
874 for (int g = 0; g < G; ++g) {
875 for (int oc = 0; oc < OC / G; ++oc) {
876 int sum = 0;
877 for (int r = 0; r < K[0]; ++r) {
878 int w = ow + conv_p.pad[0] - r * conv_p.dilation[0];
879 int w_in = w / conv_p.stride[0];
880 for (int ic = 0; ic < IC / G; ++ic) {
881 int a = (w_in * conv_p.stride[0] == w && w_in >= 0 &&
882 w_in < IN_DIM[0])
883 ? A[(n * IN_DIM[0] + w_in) * IC + g * (IC / G) + ic]
884 : A_zero_point;
885 int b =
886 B[((g * K[0] + r) * IC / G + ic) * (OC / G) +
887 oc]; // G K IC/G OC/G after transpose
888 sum += a * b;
889 } // for each ic
890 } // for each r
891 C[(n * OUT_DIM[0] + ow) * OC + g * (OC / G) + oc] = sum;
892 } // for each oc
893 } // for each g
      } // for each ow
895 } // for each n
896 } else {
897 for (int n = 0; n < conv_p.MB; ++n) {
898 for (int w = 0; w < OUT_DIM[0]; ++w) {
899 for (int g = 0; g < G; ++g) {
900 for (int m = 0; m < OC / G; ++m) {
901 int sum = 0;
902 for (int r = 0; r < K[0]; ++r) {
903 int w_in = -conv_p.pad[0] + w * conv_p.stride[0] +
904 r * conv_p.dilation[0];
905 for (int c = 0; c < IC / G; ++c) {
906 int a = w_in < 0 || w_in >= IN_DIM[0]
907 ? A_zero_point
908 : A[(n * IN_DIM[0] + w_in) * IC + g * (IC / G) + c];
909 int b =
910 B[((g * K[0] + r) * (IC / G) + c) * (OC / G) +
911 m]; // G K IC/G OC/G after transpose
912 sum += a * b;
913 } // for each c
914 } // for each r
915 C[(n * OUT_DIM[0] + w) * OC + g * (OC / G) + m] = sum;
          } // for each m
        } // for each group
      } // for each w
919 } // for each n
920 }
921}
922
923// 2D Conv
924template <>
925FBGEMM_API void conv_ref(
926 const conv_param_t<2>& conv_p,
927 const uint8_t* A,
928 int32_t A_zero_point,
929 const int8_t* B,
930 int32_t* C) {
931 // filters are assumed to be in G RS C/G x K format
932 int IC = conv_p.IC;
933 int OC = conv_p.OC;
934 int G = conv_p.G;
935 assert(IC % G == 0);
936 assert(OC % G == 0);
937 array<int, 2> IN_DIM = conv_p.IN_DIM;
938 array<int, 2> OUT_DIM = conv_p.OUT_DIM;
939 array<int, 2> K = conv_p.K;
940
941 if (conv_p.transposed) {
    // For the reference implementation, there is no padding on the input
    // buffer; the padding specifies how much we remove from the output
    // buffers.
944 for (int n = 0; n < conv_p.MB; ++n) {
945 for (int oh = 0; oh < OUT_DIM[0]; ++oh) {
946 for (int ow = 0; ow < OUT_DIM[1]; ++ow) {
947 // stride on output is fractional stride on input
948 // conv index is
949 // int h_in =
950 // -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
951 // int w_in =
952 // -conv_p.pad[1] + w * conv_p.stride[1] + s * conv_p.dilation[1];
953 // so we reverse it
954 for (int g = 0; g < G; ++g) {
955 for (int oc = 0; oc < OC / G; ++oc) {
956 int sum = 0;
957 for (int r = 0; r < K[0]; ++r) {
958 for (int s = 0; s < K[1]; ++s) {
959 int h = oh + conv_p.pad[0] - r * conv_p.dilation[0];
960 int w = ow + conv_p.pad[1] - s * conv_p.dilation[1];
961 int h_in = h / conv_p.stride[0];
962 int w_in = w / conv_p.stride[1];
963 for (int ic = 0; ic < IC / G; ++ic) {
964 int a = (h_in * conv_p.stride[0] == h && h_in >= 0 &&
965 h_in < IN_DIM[0] && w_in * conv_p.stride[1] == w &&
966 w_in >= 0 && w_in < IN_DIM[1])
967 ? A[((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
968 g * (IC / G) + ic]
969 : A_zero_point;
970 int b =
971 B[((((g * K[0] + r) * K[1] + s) * (IC / G) + ic) * OC /
972 G) +
973 oc]; // G R S IC OC after transpose
974 sum += a * b;
975 } // for each ic
976 } // for each s
977 } // for each r
978 C[((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * OC + g * (OC / G) +
979 oc] = sum;
980 } // for each oc
981 } // for each g
        } // for each ow
      } // for each oh
984 } // for each n
985 } else {
986 for (int n = 0; n < conv_p.MB; ++n) {
987 for (int h = 0; h < OUT_DIM[0]; ++h) {
988 for (int w = 0; w < OUT_DIM[1]; ++w) {
989 for (int g = 0; g < G; ++g) {
990 for (int m = 0; m < OC / G; ++m) {
991 int sum = 0;
992 for (int r = 0; r < K[0]; ++r) {
993 int h_in = -conv_p.pad[0] + h * conv_p.stride[0] +
994 r * conv_p.dilation[0];
995 for (int s = 0; s < K[1]; ++s) {
996 int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
997 s * conv_p.dilation[1];
998 for (int c = 0; c < IC / G; ++c) {
999 int a = h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
1000 w_in >= IN_DIM[1]
1001 ? A_zero_point
1002 : A[((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
1003 g * (IC / G) + c];
1004 int b =
1005 B[(((g * K[0] + r) * K[1] + s) * (IC / G) + c) *
1006 (OC / G) +
1007 m];
1008 sum += a * b;
1009 } // for each c
1010 } // for each s
1011 } // for each r
1012 C[((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * OC + g * (OC / G) +
1013 m] = sum;
1014 } // for each m
1015 } // for each group
1016 } // for each w
1017 } // for each h
1018 } // for each n
1019 }
1020}
1021
1022// 3D Conv
1023template <>
1024FBGEMM_API void conv_ref(
1025 const conv_param_t<3>& conv_p,
1026 const uint8_t* A,
1027 int32_t A_zero_point,
1028 const int8_t* B,
1029 int32_t* C) {
1030 // filters are assumed to be in G QRS C/G x K format
1031 int IC = conv_p.IC;
1032 int OC = conv_p.OC;
1033 int G = conv_p.G;
1034 assert(IC % G == 0);
1035 assert(OC % G == 0);
1036 array<int, 3> IN_DIM = conv_p.IN_DIM;
1037 array<int, 3> OUT_DIM = conv_p.OUT_DIM;
1038 array<int, 3> K = conv_p.K;
1039
1040 if (conv_p.transposed) {
    // For the reference implementation, there is no padding on the input
    // buffer; the padding specifies how much we remove from the output
    // buffers.
1043 for (int n = 0; n < conv_p.MB; ++n) {
1044 for (int ot = 0; ot < OUT_DIM[0]; ++ot) {
1045 for (int oh = 0; oh < OUT_DIM[1]; ++oh) {
1046 for (int ow = 0; ow < OUT_DIM[2]; ++ow) {
1047 // stride on output is fractional stride on input
1048 // conv index is
1049 // int t_in =
1050 // -conv_p.pad[0] + t * conv_p.stride[0] + q *
1051 // conv_p.dilation[0];
1052 // int h_in =
1053 // -conv_p.pad[1] + h * conv_p.stride[1] + r *
1054 // conv_p.dilation[1];
1055 // int w_in =
1056 // -conv_p.pad[2] + w * conv_p.stride[2] + s *
1057 // conv_p.dilation[2];
1058 // so we reverse it
1059 for (int g = 0; g < G; ++g) {
1060 for (int oc = 0; oc < OC / G; ++oc) {
1061 int sum = 0;
1062 for (int q = 0; q < K[0]; ++q) {
1063 for (int r = 0; r < K[1]; ++r) {
1064 for (int s = 0; s < K[2]; ++s) {
1065 int t = ot + conv_p.pad[0] - q * conv_p.dilation[0];
1066 int h = oh + conv_p.pad[1] - r * conv_p.dilation[1];
1067 int w = ow + conv_p.pad[2] - s * conv_p.dilation[2];
1068 int t_in = t / conv_p.stride[0];
1069 int h_in = h / conv_p.stride[1];
1070 int w_in = w / conv_p.stride[2];
1071 for (int ic = 0; ic < IC / G; ++ic) {
1072 int a =
1073 (t_in * conv_p.stride[0] == t && t_in >= 0 &&
1074 t_in < IN_DIM[0] && h_in * conv_p.stride[1] == h &&
1075 h_in >= 0 && h_in < IN_DIM[1] &&
1076 w_in * conv_p.stride[2] == w && w_in >= 0 &&
1077 w_in < IN_DIM[2])
1078 ? A[((((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
1079 IN_DIM[2]) +
1080 w_in) *
1081 IC +
1082 g * (IC / G) + ic]
1083 : A_zero_point;
1084 int b =
1085 B[((((((g * K[0] + q)) * K[1] + r) * K[2] + s) *
1086 (IC / G) +
1087 ic) *
1088 (OC / G)) +
1089 oc]; // G Q R S Cin/G Cout/G after transpose
1090 sum += a * b;
1091 } // for each ic
1092 } // for each s
1093 } // for each r
1094 } // for each q
1095 C[(((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) * OUT_DIM[2] +
1096 ow) *
1097 OC +
1098 g * (OC / G) + oc] = sum;
1099 } // for each oc
1100 } // for each g
1101 } // for each ow
1102 } // for each oh
1103 } // for each ot
1104 } // for each n
1105 } else {
1106 for (int n = 0; n < conv_p.MB; ++n) {
1107 for (int t = 0; t < OUT_DIM[0]; ++t) {
1108 for (int h = 0; h < OUT_DIM[1]; ++h) {
1109 for (int w = 0; w < OUT_DIM[2]; ++w) {
1110 for (int g = 0; g < G; ++g) {
1111 for (int m = 0; m < OC / G; ++m) {
1112 int sum = 0;
1113 for (int q = 0; q < K[0]; ++q) {
1114 int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
1115 q * conv_p.dilation[0];
1116 for (int r = 0; r < K[1]; ++r) {
1117 int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
1118 r * conv_p.dilation[1];
1119 for (int s = 0; s < K[2]; ++s) {
1120 int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
1121 s * conv_p.dilation[2];
1122 for (int c = 0; c < IC / G; ++c) {
1123 int a = t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
1124 h_in >= IN_DIM[1] || w_in < 0 ||
1125 w_in >= IN_DIM[2]
1126 ? A_zero_point
1127 : A[(((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
1128 IN_DIM[2] +
1129 w_in) *
1130 IC +
1131 g * (IC / G) + c];
1132 int b =
1133 B[((((g * K[0] + q) * K[1] + r) * K[2] + s) *
1134 (IC / G) +
1135 c) *
1136 (OC / G) +
1137 m];
1138 sum += a * b;
1139 } // for each c
1140 } // for each s
1141 } // for each r
1142 } // for each q
1143 C[(((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) * OUT_DIM[2] + w) *
1144 OC +
1145 g * (OC / G) + m] = sum;
1146 } // for each m
1147 } // for each group
1148 } // for each w
1149 } // for each h
1150 } // for each t
1151 } // for each n
1152 }
1153}
1154
1155template <int SPATIAL_DIM>
1156void transposeConvWeights(
1157 const conv_param_t<SPATIAL_DIM>& conv_p,
1158 const std::int8_t* src,
1159 std::int8_t* dest) {
1160 int G = conv_p.G;
1161 int IC_per_G = conv_p.IC / conv_p.G;
1162 int OC_per_G = conv_p.OC / conv_p.G;
1163
1164 int filter_prod = std::accumulate(
1165 conv_p.K.begin(),
1166 conv_p.K.begin() + SPATIAL_DIM,
1167 1,
1168 std::multiplies<int>());
1169 // Transforms weights from G K/G (T R S C/G) to G (T R S C/G) K/G format.
1170 for (int g = 0; g < G; ++g) {
1171 for (int k = 0; k < OC_per_G; ++k) {
1172 for (int f = 0; f < filter_prod; ++f) {
1173 for (int c = 0; c < IC_per_G; ++c) {
1174 dest[((g * filter_prod + f) * IC_per_G + c) * OC_per_G + k] =
1175 src[((g * OC_per_G + k) * filter_prod + f) * IC_per_G + c];
1176 }
1177 }
1178 }
1179 }
1180}
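
// Illustrative example of the mapping above: with G = 1, K = {3, 3}
// (filter_prod = 9), IC = 4 and OC = 8, the source element at
// (k, f, c) = (5, 2, 3), i.e. src index (5 * 9 + 2) * 4 + 3 = 191, moves to
// dest index (2 * 4 + 3) * 8 + 5 = 93.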
1181
1182template float convert_to_float_ref(float src, bool is_bf16);
1183template float convert_to_float_ref(uint16_t src, bool is_bf16);
1184template float convert_from_float_ref(float src, bool is_bf16);
1185template uint16_t convert_from_float_ref(float bfloat16, bool is_bf16);
1186
1187template <
1188 typename InType,
1189 typename IndexType,
1190 typename OffsetType,
1191 typename OutType>
1192bool EmbeddingSpMDM_ref(
1193 const int64_t block_size,
1194 const int64_t output_size,
1195 const int64_t index_size,
1196 const int64_t data_size,
1197 const InType* input,
1198 const IndexType* indices,
1199 const OffsetType* offsets_or_lengths,
1200 const float* weights, // optional, can be null for non-weighted sum
1201 bool normalize_by_lengths,
1202 OutType* out,
1203 bool is_weight_positional,
1204 bool use_offsets,
1205 int64_t output_stride /*=-1*/,
1206 int64_t input_stride /*=-1*/,
1207 bool scale_bias_last,
1208 bool no_bag,
1209 bool is_bf16 /*=false*/) {
1210 bool is8bit = is_same<InType, uint8_t>::value;
1211 if (output_stride == -1) {
1212 output_stride = block_size;
1213 }
1214
1215 vector<float> buf(block_size);
1216
1217 if (is8bit) {
    // block_size is the number of elements and input_stride is the size of
    // an entire row, including scale and bias.
1220 if (input_stride == -1) {
1221 // scale_bias_last == false is for table batched embedding that stores
1222 // scale and bias in float16
1223 const auto scale_bias_offset =
1224 2 * (scale_bias_last ? sizeof(float) : sizeof(float16));
1225 input_stride = block_size + scale_bias_offset;
1226 }
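    // For reference, each fused 8-bit row is laid out as
    //   scale_bias_last == true:
    //     [uint8 data (block_size)] [float scale] [float bias]
    //   scale_bias_last == false:
    //     [float16 scale] [float16 bias] [uint8 data (block_size)]
    // which is why the scale/bias pointer and the data offset below depend on
    // scale_bias_last.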
1227 int64_t current = 0;
1228
1229 if (no_bag) {
1230 for (int m = 0; m < output_size; ++m) {
1231 memset(buf.data(), 0, sizeof(float) * block_size);
1232 int64_t idx = indices[m];
1233
1234 if (idx < 0 || idx >= data_size) {
1235 return false;
1236 }
1237
1238 const float* scale_bias = reinterpret_cast<const float*>(
1239 input + input_stride * idx + (scale_bias_last ? block_size : 0));
1240
1241 float weight = 1.0f;
1242 if (weights) {
1243 weight = weights[m];
1244 }
1245
1246 float scale, bias;
1247 if (scale_bias_last) {
1248 scale = weight * scale_bias[0];
1249 bias = weight * scale_bias[1];
1250 } else {
1251 scale = weight *
1252 cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[0]);
1253 bias = weight *
1254 cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[1]);
1255 }
1256
1257 for (int j = 0; j < block_size; ++j) {
1258 buf[j] = std::fma(
1259 scale,
1260 input
1261 [input_stride * idx + j +
1262 (scale_bias_last ? 0 : 2 * sizeof(float16))],
1263 buf[j] + bias);
1264 }
1265 for (int j = 0; j < block_size; ++j) {
1266 out[j] = is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j])
1267 : buf[j];
1268 }
1269 out += output_stride;
1270 } // m
1271 return true;
1272 } // no_bag
1273
1274 for (int m = 0; m < output_size; ++m) {
1275 memset(buf.data(), 0, sizeof(float) * block_size);
1276 int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
1277 : offsets_or_lengths[m];
1278 if (current + len > index_size) {
1279 return false;
1280 }
1281 for (int i = 0; i < len; ++i) {
1282 int64_t idx = indices[current];
1283 if (idx < 0 || idx >= data_size) {
1284 return false;
1285 }
1286
1287 const float* scale_bias = reinterpret_cast<const float*>(
1288 input + input_stride * idx + (scale_bias_last ? block_size : 0));
1289
1290 float weight = 1.0f;
1291 if (weights) {
1292 weight = weights[is_weight_positional ? i : current];
1293 }
1294 float scale, bias;
1295 if (scale_bias_last) {
1296 scale = weight * scale_bias[0];
1297 bias = weight * scale_bias[1];
1298 } else {
1299 scale = weight *
1300 cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[0]);
1301 bias = weight *
1302 cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[1]);
1303 }
1304
1305 for (int j = 0; j < block_size; ++j) {
1306 buf[j] = std::fma(
1307 scale,
1308 input
1309 [input_stride * idx + j +
1310 (scale_bias_last ? 0 : 2 * sizeof(float16))],
1311 buf[j] + bias);
1312 }
1313
1314 ++current;
1315 }
1316 if (normalize_by_lengths && len) {
1317 float scale = 1.f / len;
1318 for (int j = 0; j < block_size; ++j) {
1319 buf[j] *= scale;
1320 }
1321 }
1322 for (int j = 0; j < block_size; ++j) {
1323 out[j] = is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j])
1324 : buf[j];
1325 }
1326 out += output_stride;
1327 }
1328 return current == index_size;
1329 } else {
1330 if (input_stride == -1) {
1331 input_stride = block_size;
1332 }
1333
1334 if (no_bag) {
1335 for (int m = 0; m < output_size; ++m) {
1336 memset(buf.data(), 0, sizeof(float) * block_size);
1337 int64_t idx = indices[m];
1338 if (idx < 0 || idx >= data_size) {
1339 return false;
1340 }
1341
1342 float w = 1.f;
1343 if (weights) {
1344 w = weights[m];
1345 }
1346
1347 for (int j = 0; j < block_size; ++j) {
1348 const InType* inptr = input + input_stride * idx + j;
1349 buf[j] = std::fma(w, convert_to_float_ref(*inptr, is_bf16), buf[j]);
1350 }
1351 for (int j = 0; j < block_size; ++j) {
1352 out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16);
1353 }
1354 out += output_stride;
1355 } // m
1356 return true;
1357 } // no_bag
1358
1359 // Reference implementation of FP32 SLS
1360 int64_t current = 0;
1361 for (int m = 0; m < output_size; ++m) {
1362 memset(buf.data(), 0, sizeof(float) * block_size);
1363 int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
1364 : offsets_or_lengths[m];
1365 if (current + len > index_size) {
1366 return false;
1367 }
1368 for (int i = 0; i < len; ++i) {
1369 int64_t idx = indices[current];
1370 if (idx < 0 || idx >= data_size) {
1371 return false;
1372 }
1373
1374 float w = 1.f;
1375 if (weights) {
1376 w = weights[is_weight_positional ? i : current];
1377 }
1378
1379 for (int j = 0; j < block_size; ++j) {
1380 const InType* inptr = input + input_stride * idx + j;
1381 buf[j] = std::fma(w, convert_to_float_ref(*inptr, is_bf16), buf[j]);
1382 }
1383
1384 ++current;
1385 }
1386 if (normalize_by_lengths && len) {
1387 float scale = 1.f / len;
1388 for (int j = 0; j < block_size; ++j) {
1389 buf[j] *= scale;
1390 }
1391 }
1392 for (int j = 0; j < block_size; ++j) {
1393 out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16);
1394 }
1395 out += output_stride;
1396 }
1397 return current == index_size;
1398 }
1399}
1400
1401template <typename IndexType, typename OffsetType, typename OutType>
1402bool EmbeddingSpMDMNBit_ref(
1403 int bit_rate,
1404 const int64_t block_size,
1405 const int64_t output_size,
1406 const int64_t index_size,
1407 const int64_t data_size,
1408 const uint8_t* input,
1409 const IndexType* indices,
1410 const OffsetType* offsets_or_lengths,
1411 const float* weights, // optional, can be null for non-weighted sum
1412 bool normalize_by_lengths,
1413 OutType* out,
1414 bool is_weight_positional,
1415 bool use_offsets,
1416 int64_t output_stride,
1417 int64_t input_stride,
1418 bool scale_bias_last) {
1419 assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
1420 int num_elem_per_byte = 8 / bit_rate;
1421
1422 if (output_stride == -1) {
1423 output_stride = block_size;
1424 }
1425
  // block_size is the number of elements and input_stride is the size of
  // an entire row, including scale and bias.
1428 const auto scale_bias_offset = 2 * sizeof(float16);
1429 if (input_stride == -1) {
1430 input_stride = (block_size + num_elem_per_byte - 1) / num_elem_per_byte +
1431 scale_bias_offset;
1432 }
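  // For reference, each fused row packs num_elem_per_byte quantized elements
  // per byte (ceil(block_size / num_elem_per_byte) bytes in total), with an
  // fp16 scale and bias stored after (scale_bias_last == true) or before
  // (scale_bias_last == false) the packed data.
  // Unpacking example (illustrative, bit_rate = 4): for the packed byte 0xB7,
  // the element with j % 2 == 0 is 0x7 and the element with j % 2 == 1 is
  // 0xB, matching the shift/mask in the inner loop below.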
1433 int64_t current = 0;
1434 vector<float> buf(block_size);
1435 for (int m = 0; m < output_size; ++m) {
1436 memset(buf.data(), 0, sizeof(float) * block_size);
1437 int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
1438 : offsets_or_lengths[m];
1439 if (current + len > index_size) {
1440 return false;
1441 }
1442 for (int i = 0; i < len; ++i) {
1443 int64_t idx = indices[current];
1444 if (idx < 0 || idx >= data_size) {
1445 return false;
1446 }
1447
1448 const float16* scale_bias = reinterpret_cast<const float16*>(
1449 input + input_stride * idx +
1450 (scale_bias_last
1451 ? (block_size + num_elem_per_byte - 1) / num_elem_per_byte
1452 : 0));
1453
1454 float weight = 1.0f;
1455 if (weights) {
1456 weight = weights[is_weight_positional ? i : current];
1457 }
1458 const float scale = weight * cpu_half2float(scale_bias[0]);
1459 const float bias = weight * cpu_half2float(scale_bias[1]);
1460
1461 for (int j = 0; j < block_size; ++j) {
1462 uint8_t quantized = input
1463 [input_stride * idx + j / num_elem_per_byte +
1464 (scale_bias_last ? 0 : scale_bias_offset)];
1465 quantized >>= (j % num_elem_per_byte) * bit_rate;
1466 quantized &= (1 << bit_rate) - 1;
1467
1468 buf[j] = std::fma(scale, quantized, buf[j] + bias);
1469 }
1470
1471 ++current;
1472 }
1473 if (normalize_by_lengths && len) {
1474 float scale = 1.f / len;
1475 for (int j = 0; j < block_size; ++j) {
1476 buf[j] *= scale;
1477 }
1478 }
1479 for (int j = 0; j < block_size; ++j) {
1480 out[j] = std::is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j])
1481 : buf[j];
1482 }
1483 out += output_stride;
1484 }
1485 return current == index_size;
1486}
1487
1488template <typename IndexType, typename OffsetType, typename OutType>
1489bool EmbeddingSpMDMFP8_ref(
1490 const int64_t block_size,
1491 const int64_t output_size,
1492 const int64_t index_size,
1493 const int64_t data_size,
1494 const uint8_t* input,
1495 const IndexType* indices,
1496 const OffsetType* offsets_or_lengths,
1497 const float* weights,
1498 bool normalize_by_lengths,
1499 OutType* out,
1500 bool is_weight_positional,
1501 bool use_offsets,
1502 int64_t output_stride,
1503 int64_t input_stride,
1504 int exponent_bits,
1505 int exponent_bias) {
1506 if (output_stride == -1) {
1507 output_stride = block_size;
1508 }
1509
1510 vector<float> buf(block_size);
1511
1512 if (input_stride == -1) {
1513 input_stride = block_size;
1514 }
1515
1516 // Reference implementation of FP8 SLS. The algorithm is similar to FP32 SLS
1517 // except for the FP8->FP32 conversion after reading the embedding weight.
1518 int64_t current = 0;
1519 for (int m = 0; m < output_size; ++m) {
1520 memset(buf.data(), 0, sizeof(float) * block_size);
1521 int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
1522 : offsets_or_lengths[m];
1523 if (current + len > index_size) {
1524 return false;
1525 }
1526 for (int i = 0; i < len; ++i) {
1527 int64_t idx = indices[current];
1528 if (idx < 0 || idx >= data_size) {
1529 return false;
1530 }
1531
1532 float w = 1.f;
1533 if (weights) {
1534 w = weights[is_weight_positional ? i : current];
1535 }
1536
1537 for (int j = 0; j < block_size; ++j) {
1538 const uint8_t* inptr = input + input_stride * idx + j;
1539 float input_f;
1540 // Dequantize FP8 to FP32 before compute
1541 Float8ToFloat_ref(*inptr, &input_f, exponent_bits, exponent_bias);
1542 buf[j] = std::fma(w, input_f, buf[j]);
1543 }
1544
1545 ++current;
1546 }
1547 if (normalize_by_lengths && len) {
1548 float scale = 1.f / len;
1549 for (int j = 0; j < block_size; ++j) {
1550 buf[j] *= scale;
1551 }
1552 }
1553 for (int j = 0; j < block_size; ++j) {
1554 out[j] =
1555 is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j]) : buf[j];
1556 }
1557 out += output_stride;
1558 }
1559 return current == index_size;
1560}
1561
1562template <typename InType, typename IndexType, typename OffsetType>
1563bool EmbeddingSpMDMRowWiseSparse_ref(
1564 const int64_t block_size,
1565 const int64_t output_size,
1566 const int64_t index_size,
1567 const int64_t uncompressed_data_size,
1568 // const int64_t compressed_data_size,
1569 const InType* input,
1570 const IndexType* indices,
1571 const int32_t* compressed_indices_table,
1572 const OffsetType* offsets_or_lengths,
1573 const float* weights, // optional, can be null for non-weighted sum
1574 bool normalize_by_lengths,
1575 float* out,
1576 bool is_weight_positional,
1577 bool use_offsets) {
1578 bool is8bit = is_same<InType, uint8_t>::value;
1579
1580 if (is8bit) {
1581 // block_size is the number of elements and fused_block_size is the size
1582 // of an entire row, including scale and bias.
1583 const auto scale_bias_offset = 2 * sizeof(float);
1584 const int64_t fused_block_size = block_size + scale_bias_offset;
1585 int64_t current = 0;
1586 for (int m = 0; m < output_size; ++m) {
1587 memset(out, 0, sizeof(float) * block_size);
1588 int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
1589 : offsets_or_lengths[m];
1590 if (current + len > index_size) {
1591 return false;
1592 }
1593 for (int i = 0; i < len; ++i) {
1594 IndexType uncompressed_idx = indices[current];
1595 if (uncompressed_idx < 0 ||
1596 uncompressed_idx >= uncompressed_data_size) {
1597 return false;
1598 }
1599 IndexType idx = compressed_indices_table[uncompressed_idx];
1600 if (idx == -1) {
1601 ++current;
1602 continue;
1603 }
1604 // if (idx < 0 || idx >= compressed_data_size) {
1605 // return false;
1606 // }
1607
1608 const float* scale_bias = reinterpret_cast<const float*>(
1609 input + fused_block_size * idx + block_size);
1610
1611 float weight = 1.0f;
1612 if (weights) {
1613 weight = weights[is_weight_positional ? i : current];
1614 }
1615 const float scale = weight * scale_bias[0];
1616 const float bias = weight * scale_bias[1];
1617
1618 for (int j = 0; j < block_size; ++j) {
1619 out[j] =
1620 std::fma(scale, input[fused_block_size * idx + j], out[j] + bias);
1621 }
1622
1623 ++current;
1624 }
1625 if (normalize_by_lengths && len) {
1626 float scale = 1.f / len;
1627 for (int j = 0; j < block_size; ++j) {
1628 out[j] *= scale;
1629 }
1630 }
1631 out += block_size;
1632 }
1633 return current == index_size;
1634 } else {
1635 // Reference implementation of FP32 SLS
1636 int64_t current = 0;
1637 for (int m = 0; m < output_size; ++m) {
1638 memset(out, 0, sizeof(float) * block_size);
1639 int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
1640 : offsets_or_lengths[m];
1641 if (current + len > index_size) {
1642 return false;
1643 }
1644 for (int i = 0; i < len; ++i) {
1645 IndexType uncompressed_idx = indices[current];
1646 if (uncompressed_idx < 0 ||
1647 uncompressed_idx >= uncompressed_data_size) {
1648 return false;
1649 }
1650 IndexType idx = compressed_indices_table[uncompressed_idx];
1651 if (idx == -1) {
1652 ++current;
1653 continue;
1654 }
1655 // if (idx < 0 || idx >= compressed_data_size) {
1656 // return false;
1657 // }
1658
1659 float w = 1.f;
1660 if (weights) {
1661 w = weights[is_weight_positional ? i : current];
1662 }
1663
1664 for (int j = 0; j < block_size; ++j) {
1665 const InType* inptr = input + block_size * idx + j;
1666 out[j] = std::fma(
1667 w,
1668 is_same<InType, float16>::value ? cpu_half2float(*inptr) : *inptr,
1669 out[j]);
1670 }
1671
1672 ++current;
1673 }
1674 if (normalize_by_lengths && len) {
1675 float scale = 1.f / len;
1676 for (int j = 0; j < block_size; ++j) {
1677 out[j] *= scale;
1678 }
1679 }
1680 out += block_size;
1681 }
1682 return current == index_size;
1683 }
1684}
1685
1686template <typename IndexType, typename OffsetType>
1687bool EmbeddingSpMDMNBitRowWiseSparse_ref(
1688 int bit_rate,
1689 const int64_t block_size,
1690 const int64_t output_size,
1691 const int64_t index_size,
1692 const int64_t uncompressed_data_size,
1693 // const int64_t compressed_data_size,
1694 const uint8_t* input,
1695 const IndexType* indices,
1696 const int32_t* compressed_indices_table,
1697 const OffsetType* offsets_or_lengths,
1698 const float* weights, // optional, can be null for non-weighted sum
1699 bool normalize_by_lengths,
1700 float* out,
1701 bool is_weight_positional,
1702 bool use_offsets) {
1703 assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
1704 int num_elem_per_byte = 8 / bit_rate;
1705
1706 // block_size is the number of elements and fused_block_size is the size of
1707 // an entire row, including scale and bias.
1708 const auto scale_bias_offset = 2 * sizeof(float16);
1709 const int64_t fused_block_size =
1710 (block_size + num_elem_per_byte - 1) / num_elem_per_byte +
1711 scale_bias_offset;
1712 int64_t current = 0;
1713 for (int m = 0; m < output_size; ++m) {
1714 memset(out, 0, sizeof(float) * block_size);
1715 int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
1716 : offsets_or_lengths[m];
1717 if (current + len > index_size) {
1718 return false;
1719 }
1720 for (int i = 0; i < len; ++i, ++current) {
1721 IndexType uncompressed_idx = indices[current];
1722 if (uncompressed_idx < 0 || uncompressed_idx >= uncompressed_data_size) {
1723 return false;
1724 }
1725 IndexType idx = compressed_indices_table[uncompressed_idx];
1726 if (idx == -1) {
1727 continue;
1728 }
1729 // if (idx < 0 || idx >= compressed_data_size) {
1730 // return false;
1731 // }
1732
1733 const float16* scale_bias = reinterpret_cast<const float16*>(
1734 input + fused_block_size * idx +
1735 (block_size + num_elem_per_byte - 1) / num_elem_per_byte);
1736
1737 float weight = 1.0f;
1738 if (weights) {
1739 weight = weights[is_weight_positional ? i : current];
1740 }
1741 const float scale = weight * cpu_half2float(scale_bias[0]);
1742 const float bias = weight * cpu_half2float(scale_bias[1]);
1743
1744 for (int j = 0; j < block_size; ++j) {
1745 uint8_t quantized =
1746 input[fused_block_size * idx + j / num_elem_per_byte];
1747 quantized >>= (j % num_elem_per_byte) * bit_rate;
1748 quantized &= (1 << bit_rate) - 1;
1749
1750 out[j] = std::fma(scale, quantized, out[j] + bias);
1751 }
1752 }
1753 if (normalize_by_lengths && len) {
1754 float scale = 1.f / len;
1755 for (int j = 0; j < block_size; ++j) {
1756 out[j] *= scale;
1757 }
1758 }
1759 out += block_size;
1760 }
1761 return current == index_size;
1762}
1763
1764template <typename IndexType>
1765int sparse_adagrad_ref(
    int num_rows, // number of rows being read
    int block_size, // number of parameters per row
1768 uint64_t param_size, // total number of parameters
1769 float* w, // input parameters
1770 const float* g, // input gradients
1771 float* h, // input momentums
1772 const IndexType* indices, // indices of each row
1773 float epsilon,
1774 float lr,
1775 float weight_decay,
1776 const double* counter,
1777 const int64_t counter_halflife) {
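  // For reference, the per-element update performed below is
  //   gj = g[j] + weight_decay * freq * w[j]
  //   h[j] += gj * gj
  //   w[j] += lr * gj / (sqrt(h[j]) + epsilon)
  // where freq = counter_halflife / counter[idx] when a positive counter is
  // provided, and 1 otherwise.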
1778 for (auto i = 0; i < num_rows; ++i) {
1779 uint64_t idx = indices[i];
1780 auto offsetI = i * block_size;
1781 auto offsetIdx = idx * block_size;
1782
1783 if (block_size + offsetIdx > param_size) {
1784 return i;
1785 }
1786
1787 float freq =
1788 (counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0;
1789
1790 const float* g_;
1791 const float* h_;
1792 const float* w_;
1793 float* nh_;
1794 float* nw_;
1795
1796 g_ = g + offsetI;
1797 h_ = h + offsetIdx;
1798 w_ = w + offsetIdx;
1799 nh_ = h + offsetIdx;
1800 nw_ = w + offsetIdx;
1801
1802 for (auto j = 0; j < block_size; ++j) {
1803 float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
1804 float hj = h_[j] + gj * gj;
1805 nh_[j] = hj;
1806 nw_[j] = w_[j] + lr * gj / (std::sqrt(hj) + epsilon);
1807 }
1808 }
1809 return num_rows;
1810}
1811
1812template <typename IndexType>
1813int rowwise_sparse_adagrad_ref(
    int num_rows, // number of rows being read
    int block_size, // number of parameters per row
1816 uint64_t param_size, // total number of parameters
1817 float* w, // input parameters
1818 const float* g, // input gradients
1819 float* h, // input momentums
1820 const IndexType* indices, // indices of each row
1821 float epsilon,
1822 float lr,
1823 float weight_decay,
1824 const double* counter,
1825 const int64_t counter_halflife) {
1826 for (auto i = 0; i < num_rows; ++i) {
1827 uint64_t idx = indices[i];
1828 auto offsetI = i * block_size;
1829 auto offsetIdx = idx * block_size;
1830
1831 if (block_size + offsetIdx > param_size) {
1832 return i;
1833 }
1834
1835 float freq =
1836 (counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0;
1837
1838 const float* g_;
1839 float* h_;
1840 float* w_;
1841
1842 g_ = g + offsetI;
1843 h_ = h + idx; // This is different from sparse adagrad
1844 w_ = w + offsetIdx;
1845
1846 float final_sum = 0.0f;
    // Note: the following code assumes fbgemm generates AVX2 code for the
    // horizontal reduction, which is OK for now because fbgemm always uses
    // AVX2 for SparseAdagrad (its performance is bounded by memory bandwidth,
    // so there is no speedup from AVX512). A non-vectorized version would
    // simply be:
    //   for (auto j = 0; j < block_size; ++j) {
    //     float gj = g_[j];
    //     final_sum += gj * gj;
    //   }
1855 constexpr int VLEN = 8;
1856 array<float, VLEN> partial_sum = {0.0f};
1857 for (auto j = 0; j < block_size; ++j) {
1858 float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
1859 partial_sum[j % VLEN] += gj * gj;
1860 }
1861 final_sum = ((partial_sum[0] + partial_sum[1]) +
1862 (partial_sum[2] + partial_sum[3])) +
1863 ((partial_sum[4] + partial_sum[5]) + (partial_sum[6] + partial_sum[7]));
1864 final_sum /= block_size;
1865 float hi = *h_ = *h_ + final_sum;
1866 float float_step = lr / (std::sqrt(hi) + epsilon);
1867
1868 for (auto j = 0; j < block_size; ++j) {
1869 float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
1870 w_[j] += gj * float_step;
1871 }
1872 }
1873 return num_rows;
1874}
1875
1876template <typename DataType, typename IndexType, typename OffsetType>
1877int rowwise_sparse_adagrad_fused_ref(
1878 int64_t block_size,
1879 int64_t output_size,
1880 int64_t index_size,
1881 int64_t data_size,
1882 DataType* w,
1883 const float* g,
1884 float* h,
1885 const IndexType* indices,
1886 const OffsetType* offsets_or_lengths,
1887 float epsilon,
1888 float lr,
1889 bool use_offsets,
1890 bool use_stochastic_rounding,
1891 int emu_vector_size,
1892 int64_t grad_stride) {
1893 if (grad_stride == -1) {
1894 grad_stride = block_size;
1895 }
1896
1897 constexpr bool isFloat16w = std::is_same<float16, DataType>::value;
  // Local random buffers to emulate a SIMD vector
  // R: generated 32-bit base random numbers
  // r: extracted 8-bit values used for rounding
1901 constexpr int VLEN_MAX = 16;
1902 uint32_t R[VLEN_MAX], r[VLEN_MAX];
1903 int vlen = emu_vector_size;
1904 if (vlen != 8 && vlen != 16) {
1905 // Raise error as it may cause buffer overflow
1906 cerr << "Not supported emu_vector_size: " << emu_vector_size << endl;
1907 return 0;
1908 }
1909
1910 int64_t current = 0;
1911 for (int m = 0; m < output_size; ++m) {
1912 int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
1913 : offsets_or_lengths[m];
1914 if (current + len > index_size) {
1915 return false;
1916 }
1917 const float* g_ = g + m * grad_stride;
    // Note: the following code assumes fbgemm generates AVX2 code for the
    // horizontal reduction, which is OK for now because fbgemm always uses
    // AVX2 for SparseAdagrad (its performance is bounded by memory bandwidth,
    // so there is no speedup from AVX512). A non-vectorized version would
    // simply be:
    //   for (auto j = 0; j < block_size; ++j) {
    //     float gj = g_[j];
    //     final_sum += gj * gj;
    //   }
1926 constexpr int VLEN_AVX2 = 8;
1927 array<float, VLEN_AVX2> partial_sum = {0.0f};
1928 for (auto j = 0; j < block_size; ++j) {
1929 float gj = g_[j];
1930 partial_sum[j % VLEN_AVX2] += gj * gj;
1931 }
1932 float final_sum = ((partial_sum[0] + partial_sum[1]) +
1933 (partial_sum[2] + partial_sum[3])) +
1934 ((partial_sum[4] + partial_sum[5]) + (partial_sum[6] + partial_sum[7]));
1935 final_sum /= block_size;
1936
1937 for (int i = 0; i < len; ++i, ++current) {
1938 int64_t idx = indices[current];
1939 if (idx < 0 || idx >= data_size) {
1940 return false;
1941 }
1942
1943 float* h_ = h + idx;
1944 DataType* w_ = w + idx * block_size;
1945
1946 float hi = *h_ = *h_ + final_sum;
1947 float float_step = lr / (std::sqrt(hi) + epsilon);
1948
1949 int nvec = (block_size + vlen - 1) / vlen;
1950 int rem = (block_size % vlen) ? (block_size % vlen) : vlen;
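      // nvec = ceil(block_size / vlen) emulated vectors; rem is the number
      // of active lanes in the last (possibly partial) vector.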
1951
      // Emulate the JIT behavior of stochastic rounding for a vector length
      // of vlen.
      //
      // A fresh R buffer is generated every 4 iterations of the nvec loop,
      // and each byte of the 32-bit values in R is used exactly once: it is
      // shifted into bits [5..12] and added to the FP32 weight before the
      // conversion to FP16.
      //
      // The shifted 8-bit region:
      // +-------+--------+--------+--------+
      // |       |        |   xxxxx|xxx     |
      // 31      23       15       7      0
      //
      // A half float has 10 mantissa bits and a float has 23, so the low 13
      // bits of the fp32 mantissa are lost when converting to fp16. The
      // random byte lands in the upper part of that lost region; together
      // with the truncating conversion below, this effectively adds an
      // approximately uniform random value in [0, 1) fp16 ULPs, i.e. it
      // implements stochastic rounding.
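      //
      // For example (values purely illustrative): if R[v] == 0x89ABCDEF,
      // then four consecutive vector steps use r[v] = 0xEF << 5, 0xCD << 5,
      // 0xAB << 5, and 0x89 << 5, i.e. the bytes of R[v] are consumed from
      // least to most significant.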
1967
1968 for (int n = 0; n < nvec; ++n) {
1969 int cur_vlen = (n == nvec - 1) ? rem : vlen;
1970 int sr_idx = n % 4;
1971
1972 if (isFloat16w && use_stochastic_rounding) {
1973 if (sr_idx == 0) {
1974 for (int v = 0; v < vlen; ++v) {
1975 R[v] = rnd128_next(v, vlen);
1976 r[v] = (R[v] & 0xFFU) << 5;
1977 }
1978 } else if (sr_idx == 1) {
1979 for (int v = 0; v < vlen; ++v) {
1980 r[v] = ((R[v] & 0xFF00U) >> 8) << 5;
1981 }
1982 } else if (sr_idx == 2) {
1983 for (int v = 0; v < vlen; ++v) {
1984 r[v] = ((R[v] & 0xFF0000U) >> 16) << 5;
1985 }
1986 } else { // 3
1987 for (int v = 0; v < vlen; ++v) {
1988 r[v] = ((R[v] & 0xFF000000U) >> 24) << 5;
1989 }
1990 }
1991 }
1992
1993 for (int v = 0; v < cur_vlen; ++v) {
1994 int j = n * vlen + v;
1995 if (isFloat16w) {
1996 union {
1997 float w_f32;
1998 uint32_t w_i32;
1999 };
2000 w_f32 = cpu_half2float(w_[j]);
2001 w_f32 = std::fma(float_step, g_[j], w_f32);
2002 if (use_stochastic_rounding) {
2003 w_i32 += r[v];
2004 }
            // Truncate (round toward zero) when converting, so that together
            // with the random bits added above this implements stochastic
            // rounding.
2006 w_[j] = cpu_float2half_rz(w_f32);
2007 } else { // float
2008 w_[j] += g_[j] * float_step;
2009 }
2010 }
2011 }
2012 }
2013 }
2014
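  // All bags were processed; report success only if exactly index_size
  // indices were consumed.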
2015 return current == index_size;
2016}
2017
2018template FBGEMM_API void transposeConvWeights(
2019 const conv_param_t<1>& conv_p,
2020 const std::int8_t* src,
2021 std::int8_t* dest);
2022
2023template FBGEMM_API void transposeConvWeights(
2024 const conv_param_t<2>& conv_p,
2025 const std::int8_t* src,
2026 std::int8_t* dest);
2027
2028template FBGEMM_API void transposeConvWeights(
2029 const conv_param_t<3>& conv_p,
2030 const std::int8_t* src,
2031 std::int8_t* dest);
2032
2033#define INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
2034 template FBGEMM_API bool EmbeddingSpMDM_ref( \
2035 const int64_t block_size, \
2036 const int64_t output_size, \
2037 const int64_t index_size, \
2038 const int64_t data_size, \
2039 const IN_TYPE* input, \
2040 const INDEX_TYPE* indices, \
2041 const OFFSET_TYPE* offsets_or_lengths, \
2042 const float* weights, \
2043 bool normalize_by_lengths, \
2044 OUT_TYPE* out, \
2045 bool is_weight_positional, \
2046 bool use_offsets, \
2047 int64_t input_stride, \
2048 int64_t output_stride, \
2049 bool scale_bias_last, \
2050 bool no_bag, \
2051 bool is_bf16);
2052
2053#define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \
2054 INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \
2055 INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float16) \
2056 template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( \
2057 const int64_t block_size, \
2058 const int64_t output_size, \
2059 const int64_t index_size, \
2060 const int64_t uncompressed_data_size, \
2061 const IN_TYPE* input, \
2062 const INDEX_TYPE* indices, \
2063 const int32_t* compressed_indices_table, \
2064 const OFFSET_TYPE* offsets_or_lengths, \
2065 const float* weights, \
2066 bool normalize_by_lengths, \
2067 float* out, \
2068 bool is_weight_positional, \
2069 bool use_offsets);
2070
2071#define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE) \
2072 INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, std::int32_t) \
2073 INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, std::int64_t)
2074
2075#define INSTANTIATE_SPMDM_INDEX_T(IN_TYPE) \
2076 INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int32_t) \
2077 INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int64_t)
2078
2079INSTANTIATE_SPMDM_INDEX_T(float)
2080INSTANTIATE_SPMDM_INDEX_T(float16)
2081INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
2082
2083#undef INSTANTIATE_SPMDM_INDEX_T
2084#undef INSTANTIATE_SPMDM_OFFSET_T
2085#undef INSTANTIATE_SPMDM_OUT_T
2086#undef INSTANTIATE_SPMDM_BASE
2087
2088#define INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
2089 template FBGEMM_API bool EmbeddingSpMDMNBit_ref( \
2090 int bit_rate, \
2091 const int64_t block_size, \
2092 const int64_t output_size, \
2093 const int64_t index_size, \
2094 const int64_t data_size, \
2095 const uint8_t* input, \
2096 const INDEX_TYPE* indices, \
2097 const OFFSET_TYPE* offsets_or_lengths, \
2098 const float* weights, \
2099 bool normalize_by_lengths, \
2100 OUT_TYPE* out, \
2101 bool is_weight_positional, \
2102 bool use_offsets, \
2103 int64_t output_stride, \
2104 int64_t input_stride, \
2105 bool scale_bias_last); \
2106 template FBGEMM_API bool EmbeddingSpMDMFP8_ref( \
2107 const int64_t block_size, \
2108 const int64_t output_size, \
2109 const int64_t index_size, \
2110 const int64_t data_size, \
2111 const uint8_t* input, \
2112 const INDEX_TYPE* indices, \
2113 const OFFSET_TYPE* offsets_or_lengths, \
2114 const float* weights, \
2115 bool normalize_by_lengths, \
2116 OUT_TYPE* out, \
2117 bool is_weight_positional, \
2118 bool use_offsets, \
2119 int64_t output_stride, \
2120 int64_t input_stride, \
2121 int exponent_bits, \
2122 int exponent_bias);
2123
2124#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \
2125 INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \
2126 INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \
2127 template FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( \
2128 int bit_rate, \
2129 const int64_t block_size, \
2130 const int64_t output_size, \
2131 const int64_t index_size, \
2132 const int64_t uncompressed_data_size, \
2133 const uint8_t* input, \
2134 const INDEX_TYPE* indices, \
2135 const int32_t* compressed_indices_table, \
2136 const OFFSET_TYPE* offsets_or_lengths, \
2137 const float* weights, \
2138 bool normalize_by_lengths, \
2139 float* out, \
2140 bool is_weight_positional, \
2141 bool use_offsets);
2142
2143#define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \
2144 INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int32_t) \
2145 INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int64_t)
2146
2147INSTANTIATE_SPMDM_OFFSET_T(int32_t)
2148INSTANTIATE_SPMDM_OFFSET_T(int64_t)
2149
2150#undef INSTANTIATE_SPMDM_OFFSET_T
2151#undef INSTANTIATE_SPMDM_OUT_T
2152#undef INSTANTIATE_SPMDM_BASE
2153
2154template FBGEMM_API int sparse_adagrad_ref(
2155 int num_rows, // number of rows reading
2156 int block_size, // number of parameters per rows
2157 std::uint64_t param_size, // total number of parameters
2158 float* w, // input parameters
2159 const float* g, // input gradients
2160 float* h, // input momentums
2161 const std::int64_t* indices, // indices of each row
2162 float epsilon,
2163 float lr,
2164 float weight_decay,
2165 const double* counter,
2166 const int64_t counter_halflife);
2167
2168template FBGEMM_API int sparse_adagrad_ref(
2169 int num_rows, // number of rows reading
2170 int block_size, // number of parameters per rows
2171 std::uint64_t param_size, // total number of parameters
2172 float* w, // input parameters
2173 const float* g, // input gradients
2174 float* h, // input momentums
2175 const std::int32_t* indices, // indices of each row
2176 float epsilon,
2177 float lr,
2178 float weight_decay,
2179 const double* counter,
2180 const int64_t counter_halflife);
2181
2182template FBGEMM_API int rowwise_sparse_adagrad_ref(
2183 int num_rows, // number of rows reading
2184 int block_size, // number of parameters per rows
2185 std::uint64_t param_size, // total number of parameters
2186 float* w, // input parameters
2187 const float* g, // input gradients
2188 float* h, // input momentums
2189 const std::int64_t* indices, // indices of each row
2190 float epsilon,
2191 float lr,
2192 float weight_decay,
2193 const double* counter,
2194 const int64_t counter_halflife);
2195
2196template FBGEMM_API int rowwise_sparse_adagrad_ref(
2197 int num_rows, // number of rows reading
2198 int block_size, // number of parameters per rows
2199 std::uint64_t param_size, // total number of parameters
2200 float* w, // input parameters
2201 const float* g, // input gradients
2202 float* h, // input momentums
2203 const std::int32_t* indices, // indices of each row
2204 float epsilon,
2205 float lr,
2206 float weight_decay,
2207 const double* counter,
2208 const int64_t counter_halflife);
2209
2210#define INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, OFFSET_TYPE) \
2211 template FBGEMM_API int rowwise_sparse_adagrad_fused_ref( \
2212 int64_t block_size, \
2213 int64_t output_size, \
2214 int64_t index_size, \
2215 int64_t data_size, \
2216 DATA_TYPE* w, \
2217 const float* g, \
2218 float* h, \
2219 const INDEX_TYPE* indices, \
2220 const OFFSET_TYPE* offsets_or_lengths, \
2221 float epsilon, \
2222 float lr, \
2223 bool use_offsets, \
2224 bool use_stochastic_rounding, \
2225 int emu_vector_size, \
2226 int64_t grad_stride);
2227
2228#define INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, INDEX_TYPE) \
2229 INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int32_t) \
2230 INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int64_t)
2231
2232#define INSTANTIATE_SPMDM_INDEX_T(DATA_TYPE) \
2233 INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, int32_t) \
2234 INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, int64_t)
2235
2236INSTANTIATE_SPMDM_INDEX_T(float)
2237INSTANTIATE_SPMDM_INDEX_T(float16)
2238
#undef INSTANTIATE_SPMDM_INDEX_T
#undef INSTANTIATE_SPMDM_OFFSET_T
#undef INSTANTIATE_SPMDM_BASE
2241
2242template <typename IndexType>
2243FBGEMM_API void compressed_indices_remap_ref(
2244 std::int32_t offsets_numel,
2245 const IndexType* indices,
2246 const int32_t* compressed_indices_mapping,
2247 const IndexType* offsets,
    const float* weights, // optional, can be null
2249 IndexType* out_indices,
2250 IndexType* out_offsets,
2251 float* out_weights) {
2252 bool has_per_sample_weights = (weights != nullptr);
2253 out_offsets[0] = offsets[0];
2254 IndexType j = 0;
2255 for (int i = 1; i < offsets_numel; i++) {
    for (IndexType k = offsets[i - 1]; k < offsets[i]; k++) {
2257 if (compressed_indices_mapping[indices[k]] != -1) {
2258 out_indices[j] = compressed_indices_mapping[indices[k]];
2259 if (has_per_sample_weights) {
2260 out_weights[j] = weights[k];
2261 }
2262 j++;
2263 }
2264 }
2265 out_offsets[i] = j;
2266 }
2267}
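
// A small worked example of the remapping above (values purely
// illustrative): with offsets = {0, 2, 4}, indices = {5, 1, 3, 1}, and a
// compressed_indices_mapping where mapping[1] = 0, mapping[5] = 2, and
// mapping[3] = -1 (pruned), the outputs are out_indices = {2, 0, 0} and
// out_offsets = {0, 2, 3}. Per-sample weights, when present, are copied
// alongside the surviving indices.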
2268
2269#define INSTANTIATE_REMAP_BASE(INDEX_TYPE) \
2270 template FBGEMM_API void compressed_indices_remap_ref( \
2271 std::int32_t offsets_numel, \
2272 const INDEX_TYPE* indices, \
2273 const int32_t* compressed_indices_mapping, \
2274 const INDEX_TYPE* offsets, \
2275 const float* weights, \
2276 INDEX_TYPE* out_indices, \
2277 INDEX_TYPE* out_offsets, \
2278 float* out_weights);
2279
2280INSTANTIATE_REMAP_BASE(int32_t)
2281INSTANTIATE_REMAP_BASE(int64_t)
2282
2283#undef INSTANTIATE_REMAP_BASE
2284
2285} // namespace fbgemm
2286