1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "./RefImplementations.h" |
9 | |
10 | #include "fbgemm/FbgemmBuild.h" |
11 | #include "fbgemm/FbgemmConvert.h" |
12 | |
13 | #include <algorithm> |
14 | #include <cassert> |
15 | #include <cmath> |
16 | #include <cstring> |
17 | #include <iostream> |
18 | #include <numeric> |
19 | #include <thread> |
20 | |
21 | using namespace std; |
22 | |
23 | namespace fbgemm { |
24 | |
25 | typedef union { |
26 | uint32_t I; |
27 | float F; |
28 | } fint32; |
29 | |
30 | // Thread-safe random number generator |
31 | // |
// Return a random 32-bit integer using xoshiro128++
33 | // http://prng.di.unimi.it/xoshiro128plusplus.c |
34 | inline uint32_t rnd128_next(int idx, int vlen) { |
35 | constexpr int VLEN_MAX = 16; // max vector size |
36 | alignas(64) static thread_local uint32_t g_rnd128_buffer[4 * VLEN_MAX]; |
37 | static thread_local bool g_rnd128_initialized = false; |
38 | |
39 | // Splitmix64: http://prng.di.unimi.it/splitmix64.c |
40 | auto rnd128_init_next = [](uint64_t& x) { |
41 | uint64_t z = (x += 0x9e3779b97f4a7c15); |
42 | z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9; |
43 | z = (z ^ (z >> 27)) * 0x94d049bb133111eb; |
44 | return z ^ (z >> 31); |
45 | }; |
46 | |
47 | auto rotl = [](const uint32_t x, int k) { |
48 | return (x << k) | (x >> (32 - k)); |
49 | }; |
50 | |
51 | if (!g_rnd128_initialized) { |
    // Initialize the random buffer with unique values per thread
53 | uint64_t h0 = std::hash<std::thread::id>{}(std::this_thread::get_id()); |
54 | for (auto i = 0; i < 4; ++i) { |
55 | // Use thread hash as seed |
56 | g_rnd128_buffer[i * VLEN_MAX] = rnd128_init_next(h0); |
57 | uint64_t h1 = g_rnd128_buffer[i * VLEN_MAX]; |
58 | for (auto v = 1; v < VLEN_MAX; ++v) { |
59 | g_rnd128_buffer[i * VLEN_MAX + v] = rnd128_init_next(h1); |
60 | } |
61 | } |
62 | g_rnd128_initialized = true; |
63 | } |
64 | |
65 | const uint32_t result = |
66 | rotl(g_rnd128_buffer[idx] + g_rnd128_buffer[3 * vlen + idx], 7) + |
67 | g_rnd128_buffer[idx]; |
68 | |
69 | const uint32_t t = g_rnd128_buffer[1 * vlen + idx] << 9; |
70 | |
71 | g_rnd128_buffer[2 * vlen + idx] ^= g_rnd128_buffer[0 * vlen + idx]; |
72 | g_rnd128_buffer[3 * vlen + idx] ^= g_rnd128_buffer[1 * vlen + idx]; |
73 | g_rnd128_buffer[1 * vlen + idx] ^= g_rnd128_buffer[2 * vlen + idx]; |
74 | g_rnd128_buffer[0 * vlen + idx] ^= g_rnd128_buffer[3 * vlen + idx]; |
75 | |
76 | g_rnd128_buffer[2 * vlen + idx] ^= t; |
77 | |
78 | g_rnd128_buffer[3 * vlen + idx] = rotl(g_rnd128_buffer[3 * vlen + idx], 11); |
79 | |
80 | return result; |
81 | } |
82 | |
83 | void FloatToFloat16_ref( |
84 | const float* src, |
85 | float16* dst, |
86 | size_t size, |
87 | bool do_clip) { |
88 | constexpr float FP16_MAX = 65504.f; |
89 | if (do_clip) { |
90 | for (size_t i = 0; i < size; i++) { |
91 | float cur_src = std::max(-FP16_MAX, std::min(src[i], FP16_MAX)); |
92 | dst[i] = cpu_float2half_rn(cur_src); |
93 | } |
94 | } else { |
95 | for (size_t i = 0; i < size; i++) { |
96 | dst[i] = cpu_float2half_rn(src[i]); |
97 | } |
98 | } |
99 | } |
100 | |
101 | void Float16ToFloat_ref(const float16* src, float* dst, size_t size) { |
102 | for (size_t i = 0; i < size; i++) { |
103 | dst[i] = cpu_half2float(src[i]); |
104 | } |
105 | } |
106 | |
107 | void FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size) { |
108 | for (size_t i = 0; i < size; i++) { |
    // Add 2^15 and right-shift by 16 to round to the nearest bfloat16
110 | dst[i] = (*reinterpret_cast<const uint32_t*>(src + i) + (1 << 15)) >> 16; |
111 | } |
112 | } |
113 | |
114 | void Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size) { |
115 | for (size_t i = 0; i < size; i++) { |
116 | uint32_t val_fp32 = |
117 | static_cast<uint32_t>(reinterpret_cast<const uint16_t*>(src)[i]) << 16; |
118 | reinterpret_cast<uint32_t*>(dst)[i] = val_fp32; |
119 | } |
120 | } |
121 | |
122 | void FloatToFloat8_ref( |
123 | const float input, |
124 | uint8_t* output, |
125 | int exponent_bits, |
126 | int exponent_bias) { |
127 | float max_pos = (1 << ((1 << exponent_bits) - 2 - exponent_bias)) * |
128 | (2 - std::pow(2, exponent_bits - 7)); |
129 | int mantissa_bits = 7 - exponent_bits; |
130 | fint32 val_out, bouncer, smallest_normal; |
131 | |
132 | val_out.F = input; |
133 | uint32_t sign_bit = val_out.I & 0x80000000; |
134 | val_out.I = val_out.I & 0x7FFFFFFF; |
135 | val_out.F = fminf(val_out.F, max_pos); |
136 | |
137 | smallest_normal.I = (127 - exponent_bias + 1) |
138 | << 23; // smallest hfp8 normal number in FP32 |
  // The normal/denormal test below must use the smallest *normal* number,
  // which has the numerical value 2^(1-bias), not the smallest denormalized
  // number.

  // The conversion for denormalized values is slightly different. HFP8 has
  // so little precision that gradual underflow is probably crucial.
145 | if (val_out.F >= smallest_normal.F) { |
    // Use round-to-nearest-even. We rely on the standard FP32 rounding
    // mechanism rather than rounding the mantissa ourselves and handling
    // tie-to-even and exponent increments by hand. We want to round off
    // 23 - mbits bits of the FP32 value val_in. This can be done by adding a
    // power of 2 exactly 23 - mbits binades larger than val_in's exponent:
    // the addition shifts val_in to the right, so rounding happens exactly
    // at the position that leaves mbits of explicit mantissa.
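    // Worked example (illustrative, with exponent_bits = 4 so mbits = 3):
    // for val_in = 1.3f the bouncer is 2^20, i.e. 23 - 3 = 20 binades above
    // the input's exponent. The FP32 ulp at 2^20 is 2^-3, so (2^20 + 1.3f)
    // rounds to 1048577.25; subtracting the bouncer back leaves 1.25f, which
    // is 1.3 rounded to 3 explicit mantissa bits.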
153 | bouncer.I = (val_out.I & 0xFF800000) + ((23 - mantissa_bits) << 23); |
154 | val_out.F = (bouncer.F + val_out.F) - bouncer.F; |
    // Adding the bouncer rounds off bits, and subtracting it back leaves the
    // desired value, albeit in FP32 encoding.
    // All that remains is to re-encode the exponent using "bias".
158 | val_out.I = uint32_t(val_out.I - ((127 - exponent_bias) << 23)) |
159 | << (8 - exponent_bits); |
160 | val_out.I = |
161 | ((val_out.I | sign_bit) >> |
162 | 24); // the 8 lsbs is the desired HFP8 encoding |
163 | |
164 | } else { |
    // When the value is in the denormal range, IEEE numbers essentially
    // become fixed-point numbers: the lsb is the smallest non-zero value,
    // 2^(1-bias-mbits). Hence, we define the bouncer so that its lsb is this
    // smallest non-zero value; adding the input to the bouncer forces the
    // appropriate rounding. Moreover, after the addition, the 8 least
    // significant bits of the sum are already the HFP8 encoding of the
    // desired result; we only need to restore the sign bit.
172 | bouncer.I = (127 + (23 + (1 - exponent_bias - mantissa_bits))) << 23; |
173 | val_out.F = bouncer.F + val_out.F; |
174 | val_out.I = val_out.I | (sign_bit >> 24); |
175 | } |
176 | |
177 | *output = val_out.I; // get the 8 lsbs |
178 | } |
179 | |
180 | void Float8ToFloat_ref( |
181 | const uint8_t input, |
182 | float* output, |
183 | int exponent_bits, |
184 | int exponent_bias) { |
185 | fint32 val_out, sign, multiplier; |
186 | |
187 | sign.I = (input & 0x80) << 24; |
188 | val_out.I = (input & 0x7F) << (24 - (8 - exponent_bits)); |
189 | // so that the mantissa bits start at the mantissa bit positions of FP32 |
190 | // encoding |
191 | |
  // Let the hfp8 mantissa bits correspond to the value frac, 0 <= frac < 1.
  // So if the hfp8 value is a normal number, its value is 2^e * (1 + frac),
  // where e is its (true, unbiased) exponent.
  // If the hfp8 value is denormal, its value is 2^(1-bias) * frac.

  // However, the bit pattern in the 8-bit exponent field of val_out.F
  // is bias+e when hfp8 is normal, and 0 when hfp8 is subnormal.
  // So, as an FP32 value, when hfp8 is normal, val_out.F represents
  // 2^(bias+e-127) * (1 + frac); when hfp8 is subnormal, val_out.F is also
  // subnormal and represents 2^(-126) * frac. In either case, val_out.F
  // equals 2^(bias-127) * (value of the hfp8 input). Thus, multiplying
  // val_out.F by 2^(127-bias) yields the hfp8 value as an FP32 number.
205 | |
206 | multiplier.I = (127 + (127 - exponent_bias)) |
207 | << 23; // multiplier.F is 2^(127-bias) |
208 | val_out.F *= multiplier.F; |
209 | val_out.I |= sign.I; |
210 | *output = val_out.F; |
211 | } |
212 | |
213 | void requantize_u8acc32_ref( |
214 | int M, |
215 | int N, |
216 | int ld, |
217 | const int32_t* inp, |
218 | uint8_t* out, |
219 | int32_t C_multiplier, |
220 | int32_t C_right_shift, |
221 | int32_t C_zero_point, |
222 | int32_t A_zero_point, |
223 | int32_t B_zero_point, |
224 | const int32_t* row_offsets, |
225 | const int32_t* col_offsets, |
226 | const int32_t* bias, |
227 | bool fuse_relu) { |
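  // Fixed-point requantization. The output is
  //   clamp(((raw * C_multiplier + nudge) >> C_right_shift) + C_zero_point,
  //         fuse_relu ? C_zero_point : 0, 255)
  // where raw is the int32 accumulator corrected for the A/B zero points and
  // bias, and `nudge` turns the arithmetic right shift into round-half-up.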
228 | int64_t nudge = 1ll << std::max(0, C_right_shift - 1); |
229 | for (int i = 0; i < M; ++i) { |
230 | for (int j = 0; j < N; ++j) { |
231 | int32_t raw = inp[i * ld + j]; |
232 | if (A_zero_point) { |
233 | raw -= A_zero_point * col_offsets[j]; |
234 | } |
235 | if (B_zero_point) { |
236 | raw -= B_zero_point * row_offsets[i]; |
237 | } |
238 | if (bias) { |
239 | raw += bias[j]; |
240 | } |
241 | |
242 | int64_t ab_64 = |
243 | static_cast<int64_t>(raw) * static_cast<int64_t>(C_multiplier); |
244 | int64_t rounded = ((ab_64 + nudge) >> C_right_shift) + C_zero_point; |
245 | |
246 | out[i * ld + j] = std::max( |
247 | fuse_relu ? static_cast<int64_t>(C_zero_point) : 0l, |
248 | std::min(static_cast<int64_t>(255l), rounded)); |
249 | } |
250 | } |
251 | } |
252 | |
253 | void requantize_u8acc32_ref( |
254 | int M, |
255 | int N, |
256 | int ld, |
257 | const int32_t* inp, |
258 | uint8_t* out, |
259 | const float* C_multiplier, |
260 | int32_t C_zero_point, |
261 | int32_t A_zero_point, |
262 | const int32_t* B_zero_point, |
263 | const int32_t* row_offsets, |
264 | const int32_t* col_offsets, |
265 | const int32_t* bias, |
266 | int ncols_per_quant_group, |
267 | bool fuse_relu) { |
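  // Per-quantization-group variant: each group of ncols_per_quant_group
  // consecutive columns shares one B zero point and one float multiplier;
  // rounding uses lrintf, i.e. round-to-nearest-even under the default FP
  // environment.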
268 | for (int i = 0; i < M; ++i) { |
269 | for (int j = 0; j < N; ++j) { |
270 | int32_t raw = inp[i * ld + j]; |
271 | if (A_zero_point) { |
272 | raw -= A_zero_point * col_offsets[j]; |
273 | } |
274 | raw -= B_zero_point[j / ncols_per_quant_group] * row_offsets[i]; |
275 | if (bias) { |
276 | raw += bias[j]; |
277 | } |
278 | |
279 | float result = raw * C_multiplier[j / ncols_per_quant_group]; |
280 | long rounded = lrintf(result) + C_zero_point; |
281 | out[i * ld + j] = std::max( |
282 | fuse_relu ? static_cast<long>(C_zero_point) : 0l, |
283 | std::min(255l, rounded)); |
284 | } |
285 | } |
286 | } |
287 | |
288 | void matmul_u8i8acc32_ref( |
289 | int M, |
290 | int N, |
291 | int K, |
292 | int lda, |
293 | int ldb, |
294 | int ldc, |
295 | const uint8_t* Aint8, |
296 | const int8_t* Bint8, |
297 | int32_t* Cint32) { |
298 | for (int i = 0; i < M; ++i) { |
299 | for (int j = 0; j < N; ++j) { |
300 | int32_t sum = 0; |
301 | for (int k = 0; k < K; ++k) { |
302 | sum += static_cast<int32_t>(Aint8[i * lda + k]) * |
303 | static_cast<int32_t>(Bint8[k * ldb + j]); |
304 | } |
305 | Cint32[i * ldc + j] = sum; |
306 | } |
307 | } |
308 | } |
309 | |
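// Models kernels that accumulate in 16 bits (e.g. vpmaddubsw-based paths):
// within each reduction block of `brow` k-values, partial sums saturate to
// int16 via clip_16bit and are spilled into a 32-bit accumulator at block
// boundaries.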
310 | void matmul_u8i8acc16_ref( |
311 | int M, |
312 | int N, |
313 | int K, |
314 | int lda, |
315 | int ldb, |
316 | int ldc, |
317 | int brow, |
318 | const uint8_t* Aint8, |
319 | const int8_t* Bint8, |
320 | int32_t* Cint32) { |
321 | for (int i = 0; i < M; ++i) { |
322 | for (int j = 0; j < N; ++j) { |
323 | int32_t sum = 0, sum_32bit = 0; |
324 | for (int k = 0; k < K; k += 2) { |
325 | int a0 = Aint8[i * lda + k]; |
326 | int b0 = Bint8[k * ldb + j]; |
327 | int a1 = 0, b1 = 0; |
328 | if (k + 1 < K) { |
329 | a1 = Aint8[i * lda + k + 1]; |
330 | b1 = Bint8[(k + 1) * ldb + j]; |
331 | } |
332 | sum = clip_16bit(sum + clip_16bit(a0 * b0 + a1 * b1)); |
333 | if ((k % brow) == (brow - 2)) { |
334 | sum_32bit += sum; |
335 | sum = 0; |
336 | } |
337 | } |
338 | Cint32[i * ldc + j] = sum_32bit + sum; |
339 | } |
340 | } |
341 | } |
342 | |
343 | void cblas_sgemm_ref( |
344 | const matrix_op_t transa, |
345 | const matrix_op_t transb, |
346 | const int m, |
347 | const int n, |
348 | const int k, |
349 | float alpha, |
350 | const float* Afp32, |
351 | int lda, |
352 | const float* Bfp32, |
353 | int ldb, |
354 | float beta, |
355 | float* Cfp32, |
356 | int ldc) { |
357 | for (int i = 0; i < m; ++i) { |
358 | for (int j = 0; j < n; ++j) { |
359 | float sum = 0; |
360 | for (int p = 0; p < k; ++p) { |
361 | float a = |
362 | (transa == matrix_op_t::NoTranspose ? Afp32[i * lda + p] |
363 | : Afp32[p * lda + i]); |
364 | float b = |
365 | (transb == matrix_op_t::NoTranspose ? Bfp32[p * ldb + j] |
366 | : Bfp32[j * ldb + p]); |
367 | sum += a * b; |
368 | } |
369 | if (beta == 0) { |
370 | Cfp32[i * ldc + j] = alpha * sum; |
371 | } else { |
372 | Cfp32[i * ldc + j] = alpha * sum + beta * Cfp32[i * ldc + j]; |
373 | } |
374 | } |
375 | } |
376 | } |
377 | |
378 | namespace { |
379 | // From https://stackoverflow.com/questions/31652875 |
380 | uint64_t umul64wide(uint64_t a, uint64_t b) { |
381 | uint64_t a_lo = static_cast<uint32_t>(a); |
382 | uint64_t a_hi = a >> 32; |
383 | uint64_t b_lo = static_cast<uint32_t>(b); |
384 | uint64_t b_hi = b >> 32; |
385 | |
386 | uint64_t p0 = a_lo * b_lo; |
387 | uint64_t p1 = a_lo * b_hi; |
388 | uint64_t p2 = a_hi * b_lo; |
389 | |
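  // The a_hi * b_hi partial product would land entirely in bits 64..127, so
  // it is omitted: the wrapping sum below is exactly the low 64 bits of the
  // full 128-bit product.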
390 | return p0 + (p1 << 32) + (p2 << 32); |
391 | } |
392 | } // namespace |
393 | |
394 | // Expected to have overflows |
395 | NO_SANITIZE("undefined" ) |
396 | void cblas_gemm_i64_i64acc_ref( |
397 | matrix_op_t transa, |
398 | matrix_op_t transb, |
399 | int M, |
400 | int N, |
401 | int K, |
402 | const int64_t* A, |
403 | int lda, |
404 | const int64_t* B, |
405 | int ldb, |
406 | bool accumulate, |
407 | int64_t* C, |
408 | int ldc) { |
409 | for (int i = 0; i < M; ++i) { |
410 | for (int j = 0; j < N; ++j) { |
411 | int64_t acc; |
412 | if (accumulate) { |
413 | acc = C[i * ldc + j]; |
414 | } else { |
415 | acc = 0; |
416 | } |
417 | for (int k = 0; k < K; ++k) { |
418 | int64_t a = |
419 | A[transa == matrix_op_t::Transpose ? i + k * lda : i * lda + k]; |
420 | int64_t b = |
421 | B[transb == matrix_op_t::Transpose ? k + j * ldb : k * ldb + j]; |
422 | int64_t lo = umul64wide(a, b); |
423 | acc += lo; |
424 | } |
425 | C[i * ldc + j] = acc; |
426 | } // j |
427 | } // i |
428 | } |
429 | |
430 | void row_offsets_u8acc32_ref( |
431 | int M, |
432 | int K, |
433 | int ld, |
434 | const uint8_t* Aint8, |
435 | int32_t* row_offsets) { |
  // Row offsets: per-row sums of A, used to correct for B's zero point.
437 | for (int i = 0; i < M; ++i) { |
438 | int32_t sum = 0; |
439 | for (int k = 0; k < K; ++k) { |
440 | sum += static_cast<int32_t>(Aint8[i * ld + k]); |
441 | } |
442 | row_offsets[i] = sum; |
443 | } |
444 | } |
445 | |
446 | void col_offsets_with_zero_pt_s8acc32_ref( |
447 | int K, |
448 | int N, |
449 | int ld, |
450 | const int8_t* Bint8, |
451 | const int32_t* B_zero_point, |
452 | int32_t* col_offsets, |
453 | int ncols_per_quant_group) { |
454 | for (int j = 0; j < N; ++j) { |
455 | int32_t sum = 0; |
456 | for (int k = 0; k < K; ++k) { |
457 | sum += Bint8[k * ld + j]; |
458 | } |
459 | col_offsets[j] = sum - B_zero_point[j / ncols_per_quant_group] * K; |
460 | } |
461 | } |
462 | |
463 | void spmdm_ref( |
464 | int M, |
465 | const uint8_t* A, |
466 | int lda, |
467 | fbgemm::CompressedSparseColumn& B, |
468 | bool accumulation, |
469 | int32_t* C, |
470 | int ldc, |
471 | int groups /*=1*/) { |
472 | int N = B.NumOfCols(); |
473 | assert(N % groups == 0); |
474 | if (!accumulation) { |
475 | for (int i = 0; i < M; ++i) { |
476 | for (int j = 0; j < N; ++j) { |
477 | C[i * ldc + j] = 0; |
478 | } |
479 | } |
480 | } |
481 | for (int g = 0; g < groups; ++g) { |
482 | for (int j = g * (N / groups); j < (g + 1) * (N / groups); ++j) { |
483 | for (int k = B.ColPtr()[j]; k < B.ColPtr()[j + 1]; ++k) { |
484 | int row = g * B.NumOfRows() + B.RowIdx()[k]; |
485 | int w = B.Values()[k]; |
486 | for (int i = 0; i < M; ++i) { |
487 | C[i * ldc + j] += A[i * lda + row] * w; |
488 | } |
489 | } |
490 | } // for each column of B |
491 | } // for each group |
492 | } |
493 | |
494 | int32_t clip_16bit(int32_t x) { |
495 | if (x > numeric_limits<int16_t>::max()) { |
496 | return std::min<int>(numeric_limits<int16_t>::max(), x); |
497 | } else if (x < numeric_limits<int16_t>::min()) { |
498 | return std::max<int>(numeric_limits<int16_t>::min(), x); |
499 | } else { |
500 | return x; |
501 | } |
502 | } |
503 | |
504 | /* Imitate the Im2Col<float, CPUContext, StorageOrder::NWC> function |
505 | * from caffe2/utils/math_cpu.cc |
506 | * NWC StorageOrder/Layout |
507 | * A: NWC: NW_0 x C_0 |
508 | * Ao: NWC: NW_1 x G KW C_0/G |
509 | */ |
510 | template <> |
511 | FBGEMM_API void im2col_ref( |
512 | const conv_param_t<1>& conv_p, |
513 | const uint8_t* A, |
514 | int32_t A_zero_point, |
515 | uint8_t* Ao) { |
516 | int IC = conv_p.IC; |
517 | int G = conv_p.G; |
518 | assert(IC % G == 0); |
519 | array<int, 1> IN_DIM = conv_p.IN_DIM; |
520 | array<int, 1> OUT_DIM = conv_p.OUT_DIM; |
521 | array<int, 1> K = conv_p.K; |
522 | |
523 | if (conv_p.transposed) { |
524 | for (int n = 0; n < conv_p.MB; ++n) { |
525 | for (int ow = 0; ow < OUT_DIM[0]; ++ow) { |
526 | for (int s = 0; s < K[0]; ++s) { |
527 | int w = ow + conv_p.pad[0] - s * conv_p.dilation[0]; |
528 | int w_in = w / conv_p.stride[0]; |
529 | if (w_in * conv_p.stride[0] == w && w_in >= 0 && w_in < IN_DIM[0]) { |
530 | for (int g = 0; g < G; ++g) { |
531 | memcpy( |
532 | Ao + (((n * OUT_DIM[0] + ow) * G + g) * K[0] + s) * (IC / G), |
533 | A + (n * IN_DIM[0] + w_in) * IC + g * (IC / G), |
534 | sizeof(uint8_t) * (IC / G)); |
535 | } |
536 | } else { |
537 | for (int g = 0; g < G; ++g) { |
538 | memset( |
539 | Ao + (((n * OUT_DIM[0] + ow) * G + g) * K[0] + s) * (IC / G), |
540 | A_zero_point, |
541 | sizeof(uint8_t) * (IC / G)); |
542 | } |
543 | } |
544 | } // for each s |
545 | } // for each ow |
546 | } // for each n |
547 | } else { |
548 | for (int n = 0; n < conv_p.MB; ++n) { |
549 | for (int w = 0; w < OUT_DIM[0]; ++w) { |
550 | for (int s = 0; s < K[0]; ++s) { |
551 | int w_in = |
552 | -conv_p.pad[0] + w * conv_p.stride[0] + s * conv_p.dilation[0]; |
553 | if (w_in < 0 || w_in >= IN_DIM[0]) { |
554 | for (int g = 0; g < G; ++g) { |
555 | memset( |
556 | Ao + (((n * OUT_DIM[0] + w) * G + g) * K[0] + s) * (IC / G), |
557 | A_zero_point, |
558 | sizeof(uint8_t) * (IC / G)); |
559 | } |
560 | } else { |
561 | for (int g = 0; g < G; ++g) { |
562 | memcpy( |
563 | Ao + (((n * OUT_DIM[0] + w) * G + g) * K[0] + s) * (IC / G), |
564 | A + (n * IN_DIM[0] + w_in) * IC + g * (IC / G), |
565 | sizeof(uint8_t) * (IC / G)); |
566 | } |
567 | } |
568 | } // for each s |
569 | } // for each w |
570 | } // for each n |
571 | } |
572 | } |
573 | |
574 | /* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function |
575 | * from caffe2/utils/math_cpu.cc |
576 | * NHWC StorageOrder/Layout |
577 | * A: NHWC: NH_0W_0 x C_0 |
578 | * Ao: NHWC: NH_1W_1 x G RS C_0/G |
579 | */ |
580 | template <> |
581 | FBGEMM_API void im2col_ref( |
582 | const conv_param_t<2>& conv_p, |
583 | const uint8_t* A, |
584 | int32_t A_zero_point, |
585 | uint8_t* Ao) { |
586 | int IC = conv_p.IC; |
587 | int G = conv_p.G; |
588 | assert(IC % G == 0); |
589 | array<int, 2> IN_DIM = conv_p.IN_DIM; |
590 | array<int, 2> OUT_DIM = conv_p.OUT_DIM; |
591 | array<int, 2> K = conv_p.K; |
592 | |
593 | if (conv_p.transposed) { |
594 | for (int n = 0; n < conv_p.MB; ++n) { |
595 | for (int oh = 0; oh < OUT_DIM[0]; ++oh) { |
596 | for (int ow = 0; ow < OUT_DIM[1]; ++ow) { |
597 | for (int r = 0; r < K[0]; ++r) { |
598 | for (int s = 0; s < K[1]; ++s) { |
599 | int h = oh + conv_p.pad[0] - r * conv_p.dilation[0]; |
600 | int w = ow + conv_p.pad[1] - s * conv_p.dilation[1]; |
601 | int h_in = h / conv_p.stride[0]; |
602 | int w_in = w / conv_p.stride[1]; |
603 | if (h_in * conv_p.stride[0] == h && h_in >= 0 && |
604 | h_in < IN_DIM[0] && w_in * conv_p.stride[1] == w && |
605 | w_in >= 0 && w_in < IN_DIM[1]) { |
606 | for (int g = 0; g < G; ++g) { |
607 | memcpy( |
608 | Ao + |
609 | (((((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * G + |
610 | g) * |
611 | K[0] + |
612 | r) * |
613 | K[1] + |
614 | s) * |
615 | (IC / G), |
616 | A + ((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC + |
617 | g * (IC / G), |
618 | sizeof(uint8_t) * (IC / G)); |
619 | } |
620 | } else { |
621 | for (int g = 0; g < G; ++g) { |
622 | memset( |
623 | Ao + |
624 | (((((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * G + |
625 | g) * |
626 | K[0] + |
627 | r) * |
628 | K[1] + |
629 | s) * |
630 | (IC / G), |
631 | A_zero_point, |
632 | sizeof(uint8_t) * (IC / G)); |
633 | } |
634 | } |
635 | } // for each s |
636 | } // for each r |
637 | } // for each ow |
638 | } // for each oh |
639 | } // for each n |
640 | } else { |
641 | for (int n = 0; n < conv_p.MB; ++n) { |
642 | for (int h = 0; h < OUT_DIM[0]; ++h) { |
643 | for (int w = 0; w < OUT_DIM[1]; ++w) { |
644 | for (int r = 0; r < K[0]; ++r) { |
645 | int h_in = |
646 | -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0]; |
647 | for (int s = 0; s < K[1]; ++s) { |
648 | int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + |
649 | s * conv_p.dilation[1]; |
650 | if (h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 || |
651 | w_in >= IN_DIM[1]) { |
652 | for (int g = 0; g < G; ++g) { |
653 | memset( |
654 | Ao + |
655 | (((((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * G + g) * |
656 | K[0] + |
657 | r) * |
658 | K[1] + |
659 | s) * |
660 | (IC / G), |
661 | A_zero_point, |
662 | sizeof(uint8_t) * (IC / G)); |
663 | } |
664 | } else { |
665 | for (int g = 0; g < G; ++g) { |
666 | memcpy( |
667 | Ao + |
668 | (((((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * G + g) * |
669 | K[0] + |
670 | r) * |
671 | K[1] + |
672 | s) * |
673 | (IC / G), |
674 | A + ((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC + |
675 | g * (IC / G), |
676 | sizeof(uint8_t) * (IC / G)); |
677 | } |
678 | } |
679 | } // for each s |
680 | } // for each r |
681 | } // for each w |
682 | } // for each h |
683 | } // for each n |
684 | } |
685 | } |
686 | |
687 | /* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function |
688 | * from caffe2/utils/math_cpu.cc |
689 | * NHWC StorageOrder/Layout |
690 | * A: NHWC: NT_0H_0W_0 x C_0 |
691 | * Ao: NHWC: NT_1H_1W_1 x G QRS C_0/G |
692 | */ |
693 | template <> |
694 | FBGEMM_API void im2col_ref( |
695 | const conv_param_t<3>& conv_p, |
696 | const uint8_t* A, |
697 | int32_t A_zero_point, |
698 | uint8_t* Ao) { |
699 | int IC = conv_p.IC; |
700 | int G = conv_p.G; |
701 | assert(IC % G == 0); |
702 | array<int, 3> IN_DIM = conv_p.IN_DIM; |
703 | array<int, 3> OUT_DIM = conv_p.OUT_DIM; |
704 | array<int, 3> K = conv_p.K; |
705 | |
706 | if (conv_p.transposed) { |
707 | for (int n = 0; n < conv_p.MB; ++n) { |
708 | for (int ot = 0; ot < OUT_DIM[0]; ++ot) { |
709 | for (int oh = 0; oh < OUT_DIM[1]; ++oh) { |
710 | for (int ow = 0; ow < OUT_DIM[2]; ++ow) { |
711 | for (int q = 0; q < K[0]; ++q) { |
712 | for (int r = 0; r < K[1]; ++r) { |
713 | for (int s = 0; s < K[2]; ++s) { |
714 | int t = ot + conv_p.pad[0] - q * conv_p.dilation[0]; |
715 | int h = oh + conv_p.pad[1] - r * conv_p.dilation[1]; |
716 | int w = ow + conv_p.pad[2] - s * conv_p.dilation[2]; |
717 | int t_in = t / conv_p.stride[0]; |
718 | int h_in = h / conv_p.stride[1]; |
719 | int w_in = w / conv_p.stride[2]; |
720 | if (t_in * conv_p.stride[0] == t && t_in >= 0 && |
721 | t_in < IN_DIM[0] && h_in * conv_p.stride[1] == h && |
722 | h_in >= 0 && h_in < IN_DIM[1] && |
723 | w_in * conv_p.stride[2] == w && w_in >= 0 && |
724 | w_in < IN_DIM[2]) { |
725 | for (int g = 0; g < G; ++g) { |
726 | memcpy( |
727 | Ao + |
728 | (((((((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) * |
729 | OUT_DIM[2] + |
730 | ow) * |
731 | G + |
732 | g) * |
733 | K[0] + |
734 | q) * |
735 | K[1] + |
736 | r) * |
737 | K[2] + |
738 | s) * |
739 | (IC / G), |
740 | A + |
741 | (((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) * |
742 | IN_DIM[2] + |
743 | w_in) * |
744 | IC + |
745 | g * (IC / G), |
746 | sizeof(uint8_t) * (IC / G)); |
747 | } |
748 | } else { |
749 | for (int g = 0; g < G; ++g) { |
750 | memset( |
751 | Ao + |
752 | (((((((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) * |
753 | OUT_DIM[2] + |
754 | ow) * |
755 | G + |
756 | g) * |
757 | K[0] + |
758 | q) * |
759 | K[1] + |
760 | r) * |
761 | K[2] + |
762 | s) * |
763 | (IC / G), |
764 | A_zero_point, |
765 | sizeof(uint8_t) * (IC / G)); |
766 | } |
767 | } |
768 | } // for each s |
769 | } // for each r |
770 | } // for each q |
771 | } // for each ow |
772 | } // for each oh |
773 | } // for each ot |
774 | } // for each n |
775 | } else { |
776 | for (int n = 0; n < conv_p.MB; ++n) { |
777 | for (int t = 0; t < OUT_DIM[0]; ++t) { |
778 | for (int h = 0; h < OUT_DIM[1]; ++h) { |
779 | for (int w = 0; w < OUT_DIM[2]; ++w) { |
780 | for (int q = 0; q < K[0]; ++q) { |
781 | int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + |
782 | q * conv_p.dilation[0]; |
783 | for (int r = 0; r < K[1]; ++r) { |
784 | int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + |
785 | r * conv_p.dilation[1]; |
786 | for (int s = 0; s < K[2]; ++s) { |
787 | int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + |
788 | s * conv_p.dilation[2]; |
789 | if (t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 || |
790 | h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]) { |
791 | for (int g = 0; g < G; ++g) { |
792 | memset( |
793 | Ao + |
794 | (((((((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) * |
795 | OUT_DIM[2] + |
796 | w) * |
797 | G + |
798 | g) * |
799 | K[0] + |
800 | q) * |
801 | K[1] + |
802 | r) * |
803 | K[2] + |
804 | s) * |
805 | (IC / G), |
806 | A_zero_point, |
807 | sizeof(uint8_t) * (IC / G)); |
808 | } |
809 | } else { |
810 | for (int g = 0; g < G; ++g) { |
811 | memcpy( |
812 | Ao + |
813 | (((((((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) * |
814 | OUT_DIM[2] + |
815 | w) * |
816 | G + |
817 | g) * |
818 | K[0] + |
819 | q) * |
820 | K[1] + |
821 | r) * |
822 | K[2] + |
823 | s) * |
824 | (IC / G), |
825 | A + |
826 | (((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) * |
827 | IN_DIM[2] + |
828 | w_in) * |
829 | IC + |
830 | g * (IC / G), |
831 | sizeof(uint8_t) * (IC / G)); |
832 | } |
833 | } |
834 | } // for each s |
835 | } // for each r |
836 | } // for each q |
837 | } // for each w |
838 | } // for each h |
839 | } // for each t |
840 | } // for each n |
841 | } |
842 | } |
843 | |
844 | // 1D Conv |
845 | template <> |
846 | FBGEMM_API void conv_ref( |
847 | const conv_param_t<1>& conv_p, |
848 | const uint8_t* A, |
849 | int32_t A_zero_point, |
850 | const int8_t* B, |
851 | int32_t* C) { |
852 | // A is assumed to be (N Lin Cin) |
853 | // B is assumed to be (G K Cin/G Cout/G) |
854 | // C is assumed to be (N Lout Cout) |
855 | int IC = conv_p.IC; |
856 | int OC = conv_p.OC; |
857 | int G = conv_p.G; |
858 | assert(IC % G == 0); |
859 | assert(OC % G == 0); |
860 | array<int, 1> IN_DIM = conv_p.IN_DIM; |
861 | array<int, 1> OUT_DIM = conv_p.OUT_DIM; |
862 | array<int, 1> K = conv_p.K; |
863 | |
864 | if (conv_p.transposed) { |
    // For the reference implementation there is no padding on the input
    // buffer; padding specifies how much we remove from the output buffers.
867 | for (int n = 0; n < conv_p.MB; ++n) { |
868 | for (int ow = 0; ow < OUT_DIM[0]; ++ow) { |
869 | // stride on output is fractional stride on input |
870 | // conv index is |
        // int w_in = -conv_p.pad[0] + w * conv_p.stride[0] +
        //            r * conv_p.dilation[0];
873 | // so we reverse it |
874 | for (int g = 0; g < G; ++g) { |
875 | for (int oc = 0; oc < OC / G; ++oc) { |
876 | int sum = 0; |
877 | for (int r = 0; r < K[0]; ++r) { |
878 | int w = ow + conv_p.pad[0] - r * conv_p.dilation[0]; |
879 | int w_in = w / conv_p.stride[0]; |
880 | for (int ic = 0; ic < IC / G; ++ic) { |
881 | int a = (w_in * conv_p.stride[0] == w && w_in >= 0 && |
882 | w_in < IN_DIM[0]) |
883 | ? A[(n * IN_DIM[0] + w_in) * IC + g * (IC / G) + ic] |
884 | : A_zero_point; |
885 | int b = |
886 | B[((g * K[0] + r) * IC / G + ic) * (OC / G) + |
887 | oc]; // G K IC/G OC/G after transpose |
888 | sum += a * b; |
889 | } // for each ic |
890 | } // for each r |
891 | C[(n * OUT_DIM[0] + ow) * OC + g * (OC / G) + oc] = sum; |
892 | } // for each oc |
893 | } // for each g |
894 | } // for each w |
895 | } // for each n |
896 | } else { |
897 | for (int n = 0; n < conv_p.MB; ++n) { |
898 | for (int w = 0; w < OUT_DIM[0]; ++w) { |
899 | for (int g = 0; g < G; ++g) { |
900 | for (int m = 0; m < OC / G; ++m) { |
901 | int sum = 0; |
902 | for (int r = 0; r < K[0]; ++r) { |
903 | int w_in = -conv_p.pad[0] + w * conv_p.stride[0] + |
904 | r * conv_p.dilation[0]; |
905 | for (int c = 0; c < IC / G; ++c) { |
906 | int a = w_in < 0 || w_in >= IN_DIM[0] |
907 | ? A_zero_point |
908 | : A[(n * IN_DIM[0] + w_in) * IC + g * (IC / G) + c]; |
909 | int b = |
910 | B[((g * K[0] + r) * (IC / G) + c) * (OC / G) + |
911 | m]; // G K IC/G OC/G after transpose |
912 | sum += a * b; |
913 | } // for each c |
914 | } // for each r |
915 | C[(n * OUT_DIM[0] + w) * OC + g * (OC / G) + m] = sum; |
          } // for each m
        } // for each group
      } // for each w
919 | } // for each n |
920 | } |
921 | } |
922 | |
923 | // 2D Conv |
924 | template <> |
925 | FBGEMM_API void conv_ref( |
926 | const conv_param_t<2>& conv_p, |
927 | const uint8_t* A, |
928 | int32_t A_zero_point, |
929 | const int8_t* B, |
930 | int32_t* C) { |
931 | // filters are assumed to be in G RS C/G x K format |
932 | int IC = conv_p.IC; |
933 | int OC = conv_p.OC; |
934 | int G = conv_p.G; |
935 | assert(IC % G == 0); |
936 | assert(OC % G == 0); |
937 | array<int, 2> IN_DIM = conv_p.IN_DIM; |
938 | array<int, 2> OUT_DIM = conv_p.OUT_DIM; |
939 | array<int, 2> K = conv_p.K; |
940 | |
941 | if (conv_p.transposed) { |
    // For the reference implementation there is no padding on the input
    // buffer; padding specifies how much we remove from the output buffers.
944 | for (int n = 0; n < conv_p.MB; ++n) { |
945 | for (int oh = 0; oh < OUT_DIM[0]; ++oh) { |
946 | for (int ow = 0; ow < OUT_DIM[1]; ++ow) { |
947 | // stride on output is fractional stride on input |
948 | // conv index is |
949 | // int h_in = |
950 | // -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0]; |
951 | // int w_in = |
952 | // -conv_p.pad[1] + w * conv_p.stride[1] + s * conv_p.dilation[1]; |
953 | // so we reverse it |
954 | for (int g = 0; g < G; ++g) { |
955 | for (int oc = 0; oc < OC / G; ++oc) { |
956 | int sum = 0; |
957 | for (int r = 0; r < K[0]; ++r) { |
958 | for (int s = 0; s < K[1]; ++s) { |
959 | int h = oh + conv_p.pad[0] - r * conv_p.dilation[0]; |
960 | int w = ow + conv_p.pad[1] - s * conv_p.dilation[1]; |
961 | int h_in = h / conv_p.stride[0]; |
962 | int w_in = w / conv_p.stride[1]; |
963 | for (int ic = 0; ic < IC / G; ++ic) { |
964 | int a = (h_in * conv_p.stride[0] == h && h_in >= 0 && |
965 | h_in < IN_DIM[0] && w_in * conv_p.stride[1] == w && |
966 | w_in >= 0 && w_in < IN_DIM[1]) |
967 | ? A[((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC + |
968 | g * (IC / G) + ic] |
969 | : A_zero_point; |
970 | int b = |
971 | B[((((g * K[0] + r) * K[1] + s) * (IC / G) + ic) * OC / |
972 | G) + |
                          oc]; // G R S IC/G OC/G after transpose
974 | sum += a * b; |
975 | } // for each ic |
976 | } // for each s |
977 | } // for each r |
978 | C[((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * OC + g * (OC / G) + |
979 | oc] = sum; |
980 | } // for each oc |
981 | } // for each g |
982 | } // for each w |
983 | } // for each h |
984 | } // for each n |
985 | } else { |
986 | for (int n = 0; n < conv_p.MB; ++n) { |
987 | for (int h = 0; h < OUT_DIM[0]; ++h) { |
988 | for (int w = 0; w < OUT_DIM[1]; ++w) { |
989 | for (int g = 0; g < G; ++g) { |
990 | for (int m = 0; m < OC / G; ++m) { |
991 | int sum = 0; |
992 | for (int r = 0; r < K[0]; ++r) { |
993 | int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + |
994 | r * conv_p.dilation[0]; |
995 | for (int s = 0; s < K[1]; ++s) { |
996 | int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + |
997 | s * conv_p.dilation[1]; |
998 | for (int c = 0; c < IC / G; ++c) { |
999 | int a = h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 || |
1000 | w_in >= IN_DIM[1] |
1001 | ? A_zero_point |
1002 | : A[((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC + |
1003 | g * (IC / G) + c]; |
1004 | int b = |
1005 | B[(((g * K[0] + r) * K[1] + s) * (IC / G) + c) * |
1006 | (OC / G) + |
1007 | m]; |
1008 | sum += a * b; |
1009 | } // for each c |
1010 | } // for each s |
1011 | } // for each r |
1012 | C[((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * OC + g * (OC / G) + |
1013 | m] = sum; |
1014 | } // for each m |
1015 | } // for each group |
1016 | } // for each w |
1017 | } // for each h |
1018 | } // for each n |
1019 | } |
1020 | } |
1021 | |
1022 | // 3D Conv |
1023 | template <> |
1024 | FBGEMM_API void conv_ref( |
1025 | const conv_param_t<3>& conv_p, |
1026 | const uint8_t* A, |
1027 | int32_t A_zero_point, |
1028 | const int8_t* B, |
1029 | int32_t* C) { |
1030 | // filters are assumed to be in G QRS C/G x K format |
1031 | int IC = conv_p.IC; |
1032 | int OC = conv_p.OC; |
1033 | int G = conv_p.G; |
1034 | assert(IC % G == 0); |
1035 | assert(OC % G == 0); |
1036 | array<int, 3> IN_DIM = conv_p.IN_DIM; |
1037 | array<int, 3> OUT_DIM = conv_p.OUT_DIM; |
1038 | array<int, 3> K = conv_p.K; |
1039 | |
1040 | if (conv_p.transposed) { |
    // For the reference implementation there is no padding on the input
    // buffer; padding specifies how much we remove from the output buffers.
1043 | for (int n = 0; n < conv_p.MB; ++n) { |
1044 | for (int ot = 0; ot < OUT_DIM[0]; ++ot) { |
1045 | for (int oh = 0; oh < OUT_DIM[1]; ++oh) { |
1046 | for (int ow = 0; ow < OUT_DIM[2]; ++ow) { |
1047 | // stride on output is fractional stride on input |
1048 | // conv index is |
1049 | // int t_in = |
1050 | // -conv_p.pad[0] + t * conv_p.stride[0] + q * |
1051 | // conv_p.dilation[0]; |
1052 | // int h_in = |
1053 | // -conv_p.pad[1] + h * conv_p.stride[1] + r * |
1054 | // conv_p.dilation[1]; |
1055 | // int w_in = |
1056 | // -conv_p.pad[2] + w * conv_p.stride[2] + s * |
1057 | // conv_p.dilation[2]; |
1058 | // so we reverse it |
1059 | for (int g = 0; g < G; ++g) { |
1060 | for (int oc = 0; oc < OC / G; ++oc) { |
1061 | int sum = 0; |
1062 | for (int q = 0; q < K[0]; ++q) { |
1063 | for (int r = 0; r < K[1]; ++r) { |
1064 | for (int s = 0; s < K[2]; ++s) { |
1065 | int t = ot + conv_p.pad[0] - q * conv_p.dilation[0]; |
1066 | int h = oh + conv_p.pad[1] - r * conv_p.dilation[1]; |
1067 | int w = ow + conv_p.pad[2] - s * conv_p.dilation[2]; |
1068 | int t_in = t / conv_p.stride[0]; |
1069 | int h_in = h / conv_p.stride[1]; |
1070 | int w_in = w / conv_p.stride[2]; |
1071 | for (int ic = 0; ic < IC / G; ++ic) { |
1072 | int a = |
1073 | (t_in * conv_p.stride[0] == t && t_in >= 0 && |
1074 | t_in < IN_DIM[0] && h_in * conv_p.stride[1] == h && |
1075 | h_in >= 0 && h_in < IN_DIM[1] && |
1076 | w_in * conv_p.stride[2] == w && w_in >= 0 && |
1077 | w_in < IN_DIM[2]) |
1078 | ? A[((((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) * |
1079 | IN_DIM[2]) + |
1080 | w_in) * |
1081 | IC + |
1082 | g * (IC / G) + ic] |
1083 | : A_zero_point; |
1084 | int b = |
1085 | B[((((((g * K[0] + q)) * K[1] + r) * K[2] + s) * |
1086 | (IC / G) + |
1087 | ic) * |
1088 | (OC / G)) + |
1089 | oc]; // G Q R S Cin/G Cout/G after transpose |
1090 | sum += a * b; |
1091 | } // for each ic |
1092 | } // for each s |
1093 | } // for each r |
1094 | } // for each q |
1095 | C[(((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) * OUT_DIM[2] + |
1096 | ow) * |
1097 | OC + |
1098 | g * (OC / G) + oc] = sum; |
1099 | } // for each oc |
1100 | } // for each g |
1101 | } // for each ow |
1102 | } // for each oh |
1103 | } // for each ot |
1104 | } // for each n |
1105 | } else { |
1106 | for (int n = 0; n < conv_p.MB; ++n) { |
1107 | for (int t = 0; t < OUT_DIM[0]; ++t) { |
1108 | for (int h = 0; h < OUT_DIM[1]; ++h) { |
1109 | for (int w = 0; w < OUT_DIM[2]; ++w) { |
1110 | for (int g = 0; g < G; ++g) { |
1111 | for (int m = 0; m < OC / G; ++m) { |
1112 | int sum = 0; |
1113 | for (int q = 0; q < K[0]; ++q) { |
1114 | int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + |
1115 | q * conv_p.dilation[0]; |
1116 | for (int r = 0; r < K[1]; ++r) { |
1117 | int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + |
1118 | r * conv_p.dilation[1]; |
1119 | for (int s = 0; s < K[2]; ++s) { |
1120 | int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + |
1121 | s * conv_p.dilation[2]; |
1122 | for (int c = 0; c < IC / G; ++c) { |
1123 | int a = t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 || |
1124 | h_in >= IN_DIM[1] || w_in < 0 || |
1125 | w_in >= IN_DIM[2] |
1126 | ? A_zero_point |
1127 | : A[(((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) * |
1128 | IN_DIM[2] + |
1129 | w_in) * |
1130 | IC + |
1131 | g * (IC / G) + c]; |
1132 | int b = |
1133 | B[((((g * K[0] + q) * K[1] + r) * K[2] + s) * |
1134 | (IC / G) + |
1135 | c) * |
1136 | (OC / G) + |
1137 | m]; |
1138 | sum += a * b; |
1139 | } // for each c |
1140 | } // for each s |
1141 | } // for each r |
1142 | } // for each q |
1143 | C[(((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) * OUT_DIM[2] + w) * |
1144 | OC + |
1145 | g * (OC / G) + m] = sum; |
1146 | } // for each m |
1147 | } // for each group |
1148 | } // for each w |
1149 | } // for each h |
1150 | } // for each t |
1151 | } // for each n |
1152 | } |
1153 | } |
1154 | |
1155 | template <int SPATIAL_DIM> |
1156 | void transposeConvWeights( |
1157 | const conv_param_t<SPATIAL_DIM>& conv_p, |
1158 | const std::int8_t* src, |
1159 | std::int8_t* dest) { |
1160 | int G = conv_p.G; |
1161 | int IC_per_G = conv_p.IC / conv_p.G; |
1162 | int OC_per_G = conv_p.OC / conv_p.G; |
1163 | |
1164 | int filter_prod = std::accumulate( |
1165 | conv_p.K.begin(), |
1166 | conv_p.K.begin() + SPATIAL_DIM, |
1167 | 1, |
1168 | std::multiplies<int>()); |
1169 | // Transforms weights from G K/G (T R S C/G) to G (T R S C/G) K/G format. |
1170 | for (int g = 0; g < G; ++g) { |
1171 | for (int k = 0; k < OC_per_G; ++k) { |
1172 | for (int f = 0; f < filter_prod; ++f) { |
1173 | for (int c = 0; c < IC_per_G; ++c) { |
1174 | dest[((g * filter_prod + f) * IC_per_G + c) * OC_per_G + k] = |
1175 | src[((g * OC_per_G + k) * filter_prod + f) * IC_per_G + c]; |
1176 | } |
1177 | } |
1178 | } |
1179 | } |
1180 | } |
1181 | |
1182 | template float convert_to_float_ref(float src, bool is_bf16); |
1183 | template float convert_to_float_ref(uint16_t src, bool is_bf16); |
1184 | template float convert_from_float_ref(float src, bool is_bf16); |
template uint16_t convert_from_float_ref(float src, bool is_bf16);
1186 | |
1187 | template < |
1188 | typename InType, |
1189 | typename IndexType, |
1190 | typename OffsetType, |
1191 | typename OutType> |
1192 | bool EmbeddingSpMDM_ref( |
1193 | const int64_t block_size, |
1194 | const int64_t output_size, |
1195 | const int64_t index_size, |
1196 | const int64_t data_size, |
1197 | const InType* input, |
1198 | const IndexType* indices, |
1199 | const OffsetType* offsets_or_lengths, |
1200 | const float* weights, // optional, can be null for non-weighted sum |
1201 | bool normalize_by_lengths, |
1202 | OutType* out, |
1203 | bool is_weight_positional, |
1204 | bool use_offsets, |
1205 | int64_t output_stride /*=-1*/, |
1206 | int64_t input_stride /*=-1*/, |
1207 | bool scale_bias_last, |
1208 | bool no_bag, |
1209 | bool is_bf16 /*=false*/) { |
1210 | bool is8bit = is_same<InType, uint8_t>::value; |
1211 | if (output_stride == -1) { |
1212 | output_stride = block_size; |
1213 | } |
1214 | |
1215 | vector<float> buf(block_size); |
1216 | |
1217 | if (is8bit) { |
    // block_size is the number of quantized elements and input_stride is the
    // size of an entire row, including scale and bias.
1220 | if (input_stride == -1) { |
1221 | // scale_bias_last == false is for table batched embedding that stores |
1222 | // scale and bias in float16 |
1223 | const auto scale_bias_offset = |
1224 | 2 * (scale_bias_last ? sizeof(float) : sizeof(float16)); |
1225 | input_stride = block_size + scale_bias_offset; |
1226 | } |
1227 | int64_t current = 0; |
1228 | |
1229 | if (no_bag) { |
1230 | for (int m = 0; m < output_size; ++m) { |
1231 | memset(buf.data(), 0, sizeof(float) * block_size); |
1232 | int64_t idx = indices[m]; |
1233 | |
1234 | if (idx < 0 || idx >= data_size) { |
1235 | return false; |
1236 | } |
1237 | |
1238 | const float* scale_bias = reinterpret_cast<const float*>( |
1239 | input + input_stride * idx + (scale_bias_last ? block_size : 0)); |
1240 | |
1241 | float weight = 1.0f; |
1242 | if (weights) { |
1243 | weight = weights[m]; |
1244 | } |
1245 | |
1246 | float scale, bias; |
1247 | if (scale_bias_last) { |
1248 | scale = weight * scale_bias[0]; |
1249 | bias = weight * scale_bias[1]; |
1250 | } else { |
1251 | scale = weight * |
1252 | cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[0]); |
1253 | bias = weight * |
1254 | cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[1]); |
1255 | } |
1256 | |
1257 | for (int j = 0; j < block_size; ++j) { |
1258 | buf[j] = std::fma( |
1259 | scale, |
1260 | input |
1261 | [input_stride * idx + j + |
1262 | (scale_bias_last ? 0 : 2 * sizeof(float16))], |
1263 | buf[j] + bias); |
1264 | } |
1265 | for (int j = 0; j < block_size; ++j) { |
1266 | out[j] = is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j]) |
1267 | : buf[j]; |
1268 | } |
1269 | out += output_stride; |
1270 | } // m |
1271 | return true; |
1272 | } // no_bag |
1273 | |
1274 | for (int m = 0; m < output_size; ++m) { |
1275 | memset(buf.data(), 0, sizeof(float) * block_size); |
1276 | int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m] |
1277 | : offsets_or_lengths[m]; |
1278 | if (current + len > index_size) { |
1279 | return false; |
1280 | } |
1281 | for (int i = 0; i < len; ++i) { |
1282 | int64_t idx = indices[current]; |
1283 | if (idx < 0 || idx >= data_size) { |
1284 | return false; |
1285 | } |
1286 | |
1287 | const float* scale_bias = reinterpret_cast<const float*>( |
1288 | input + input_stride * idx + (scale_bias_last ? block_size : 0)); |
1289 | |
1290 | float weight = 1.0f; |
1291 | if (weights) { |
1292 | weight = weights[is_weight_positional ? i : current]; |
1293 | } |
1294 | float scale, bias; |
1295 | if (scale_bias_last) { |
1296 | scale = weight * scale_bias[0]; |
1297 | bias = weight * scale_bias[1]; |
1298 | } else { |
1299 | scale = weight * |
1300 | cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[0]); |
1301 | bias = weight * |
1302 | cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[1]); |
1303 | } |
1304 | |
1305 | for (int j = 0; j < block_size; ++j) { |
1306 | buf[j] = std::fma( |
1307 | scale, |
1308 | input |
1309 | [input_stride * idx + j + |
1310 | (scale_bias_last ? 0 : 2 * sizeof(float16))], |
1311 | buf[j] + bias); |
1312 | } |
1313 | |
1314 | ++current; |
1315 | } |
1316 | if (normalize_by_lengths && len) { |
1317 | float scale = 1.f / len; |
1318 | for (int j = 0; j < block_size; ++j) { |
1319 | buf[j] *= scale; |
1320 | } |
1321 | } |
1322 | for (int j = 0; j < block_size; ++j) { |
1323 | out[j] = is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j]) |
1324 | : buf[j]; |
1325 | } |
1326 | out += output_stride; |
1327 | } |
1328 | return current == index_size; |
1329 | } else { |
1330 | if (input_stride == -1) { |
1331 | input_stride = block_size; |
1332 | } |
1333 | |
1334 | if (no_bag) { |
1335 | for (int m = 0; m < output_size; ++m) { |
1336 | memset(buf.data(), 0, sizeof(float) * block_size); |
1337 | int64_t idx = indices[m]; |
1338 | if (idx < 0 || idx >= data_size) { |
1339 | return false; |
1340 | } |
1341 | |
1342 | float w = 1.f; |
1343 | if (weights) { |
1344 | w = weights[m]; |
1345 | } |
1346 | |
1347 | for (int j = 0; j < block_size; ++j) { |
1348 | const InType* inptr = input + input_stride * idx + j; |
1349 | buf[j] = std::fma(w, convert_to_float_ref(*inptr, is_bf16), buf[j]); |
1350 | } |
1351 | for (int j = 0; j < block_size; ++j) { |
1352 | out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16); |
1353 | } |
1354 | out += output_stride; |
1355 | } // m |
1356 | return true; |
1357 | } // no_bag |
1358 | |
1359 | // Reference implementation of FP32 SLS |
1360 | int64_t current = 0; |
1361 | for (int m = 0; m < output_size; ++m) { |
1362 | memset(buf.data(), 0, sizeof(float) * block_size); |
1363 | int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m] |
1364 | : offsets_or_lengths[m]; |
1365 | if (current + len > index_size) { |
1366 | return false; |
1367 | } |
1368 | for (int i = 0; i < len; ++i) { |
1369 | int64_t idx = indices[current]; |
1370 | if (idx < 0 || idx >= data_size) { |
1371 | return false; |
1372 | } |
1373 | |
1374 | float w = 1.f; |
1375 | if (weights) { |
1376 | w = weights[is_weight_positional ? i : current]; |
1377 | } |
1378 | |
1379 | for (int j = 0; j < block_size; ++j) { |
1380 | const InType* inptr = input + input_stride * idx + j; |
1381 | buf[j] = std::fma(w, convert_to_float_ref(*inptr, is_bf16), buf[j]); |
1382 | } |
1383 | |
1384 | ++current; |
1385 | } |
1386 | if (normalize_by_lengths && len) { |
1387 | float scale = 1.f / len; |
1388 | for (int j = 0; j < block_size; ++j) { |
1389 | buf[j] *= scale; |
1390 | } |
1391 | } |
1392 | for (int j = 0; j < block_size; ++j) { |
1393 | out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16); |
1394 | } |
1395 | out += output_stride; |
1396 | } |
1397 | return current == index_size; |
1398 | } |
1399 | } |
1400 | |
1401 | template <typename IndexType, typename OffsetType, typename OutType> |
1402 | bool EmbeddingSpMDMNBit_ref( |
1403 | int bit_rate, |
1404 | const int64_t block_size, |
1405 | const int64_t output_size, |
1406 | const int64_t index_size, |
1407 | const int64_t data_size, |
1408 | const uint8_t* input, |
1409 | const IndexType* indices, |
1410 | const OffsetType* offsets_or_lengths, |
1411 | const float* weights, // optional, can be null for non-weighted sum |
1412 | bool normalize_by_lengths, |
1413 | OutType* out, |
1414 | bool is_weight_positional, |
1415 | bool use_offsets, |
1416 | int64_t output_stride, |
1417 | int64_t input_stride, |
1418 | bool scale_bias_last) { |
  assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
1420 | int num_elem_per_byte = 8 / bit_rate; |
1421 | |
1422 | if (output_stride == -1) { |
1423 | output_stride = block_size; |
1424 | } |
1425 | |
  // block_size is the number of quantized elements and input_stride is the
  // size of an entire row, including scale and bias.
1428 | const auto scale_bias_offset = 2 * sizeof(float16); |
1429 | if (input_stride == -1) { |
1430 | input_stride = (block_size + num_elem_per_byte - 1) / num_elem_per_byte + |
1431 | scale_bias_offset; |
1432 | } |
1433 | int64_t current = 0; |
1434 | vector<float> buf(block_size); |
1435 | for (int m = 0; m < output_size; ++m) { |
1436 | memset(buf.data(), 0, sizeof(float) * block_size); |
1437 | int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m] |
1438 | : offsets_or_lengths[m]; |
1439 | if (current + len > index_size) { |
1440 | return false; |
1441 | } |
1442 | for (int i = 0; i < len; ++i) { |
1443 | int64_t idx = indices[current]; |
1444 | if (idx < 0 || idx >= data_size) { |
1445 | return false; |
1446 | } |
1447 | |
1448 | const float16* scale_bias = reinterpret_cast<const float16*>( |
1449 | input + input_stride * idx + |
1450 | (scale_bias_last |
1451 | ? (block_size + num_elem_per_byte - 1) / num_elem_per_byte |
1452 | : 0)); |
1453 | |
1454 | float weight = 1.0f; |
1455 | if (weights) { |
1456 | weight = weights[is_weight_positional ? i : current]; |
1457 | } |
1458 | const float scale = weight * cpu_half2float(scale_bias[0]); |
1459 | const float bias = weight * cpu_half2float(scale_bias[1]); |
1460 | |
1461 | for (int j = 0; j < block_size; ++j) { |
1462 | uint8_t quantized = input |
1463 | [input_stride * idx + j / num_elem_per_byte + |
1464 | (scale_bias_last ? 0 : scale_bias_offset)]; |
1465 | quantized >>= (j % num_elem_per_byte) * bit_rate; |
1466 | quantized &= (1 << bit_rate) - 1; |
1467 | |
1468 | buf[j] = std::fma(scale, quantized, buf[j] + bias); |
1469 | } |
1470 | |
1471 | ++current; |
1472 | } |
1473 | if (normalize_by_lengths && len) { |
1474 | float scale = 1.f / len; |
1475 | for (int j = 0; j < block_size; ++j) { |
1476 | buf[j] *= scale; |
1477 | } |
1478 | } |
1479 | for (int j = 0; j < block_size; ++j) { |
1480 | out[j] = std::is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j]) |
1481 | : buf[j]; |
1482 | } |
1483 | out += output_stride; |
1484 | } |
1485 | return current == index_size; |
1486 | } |
1487 | |
1488 | template <typename IndexType, typename OffsetType, typename OutType> |
1489 | bool EmbeddingSpMDMFP8_ref( |
1490 | const int64_t block_size, |
1491 | const int64_t output_size, |
1492 | const int64_t index_size, |
1493 | const int64_t data_size, |
1494 | const uint8_t* input, |
1495 | const IndexType* indices, |
1496 | const OffsetType* offsets_or_lengths, |
1497 | const float* weights, |
1498 | bool normalize_by_lengths, |
1499 | OutType* out, |
1500 | bool is_weight_positional, |
1501 | bool use_offsets, |
1502 | int64_t output_stride, |
1503 | int64_t input_stride, |
1504 | int exponent_bits, |
1505 | int exponent_bias) { |
1506 | if (output_stride == -1) { |
1507 | output_stride = block_size; |
1508 | } |
1509 | |
1510 | vector<float> buf(block_size); |
1511 | |
1512 | if (input_stride == -1) { |
1513 | input_stride = block_size; |
1514 | } |
1515 | |
1516 | // Reference implementation of FP8 SLS. The algorithm is similar to FP32 SLS |
1517 | // except for the FP8->FP32 conversion after reading the embedding weight. |
1518 | int64_t current = 0; |
1519 | for (int m = 0; m < output_size; ++m) { |
1520 | memset(buf.data(), 0, sizeof(float) * block_size); |
1521 | int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m] |
1522 | : offsets_or_lengths[m]; |
1523 | if (current + len > index_size) { |
1524 | return false; |
1525 | } |
1526 | for (int i = 0; i < len; ++i) { |
1527 | int64_t idx = indices[current]; |
1528 | if (idx < 0 || idx >= data_size) { |
1529 | return false; |
1530 | } |
1531 | |
1532 | float w = 1.f; |
1533 | if (weights) { |
1534 | w = weights[is_weight_positional ? i : current]; |
1535 | } |
1536 | |
1537 | for (int j = 0; j < block_size; ++j) { |
1538 | const uint8_t* inptr = input + input_stride * idx + j; |
1539 | float input_f; |
1540 | // Dequantize FP8 to FP32 before compute |
1541 | Float8ToFloat_ref(*inptr, &input_f, exponent_bits, exponent_bias); |
1542 | buf[j] = std::fma(w, input_f, buf[j]); |
1543 | } |
1544 | |
1545 | ++current; |
1546 | } |
1547 | if (normalize_by_lengths && len) { |
1548 | float scale = 1.f / len; |
1549 | for (int j = 0; j < block_size; ++j) { |
1550 | buf[j] *= scale; |
1551 | } |
1552 | } |
1553 | for (int j = 0; j < block_size; ++j) { |
1554 | out[j] = |
1555 | is_same<OutType, float16>::value ? cpu_float2half_rn(buf[j]) : buf[j]; |
1556 | } |
1557 | out += output_stride; |
1558 | } |
1559 | return current == index_size; |
1560 | } |
1561 | |
1562 | template <typename InType, typename IndexType, typename OffsetType> |
1563 | bool EmbeddingSpMDMRowWiseSparse_ref( |
1564 | const int64_t block_size, |
1565 | const int64_t output_size, |
1566 | const int64_t index_size, |
1567 | const int64_t uncompressed_data_size, |
1568 | // const int64_t compressed_data_size, |
1569 | const InType* input, |
1570 | const IndexType* indices, |
1571 | const int32_t* compressed_indices_table, |
1572 | const OffsetType* offsets_or_lengths, |
1573 | const float* weights, // optional, can be null for non-weighted sum |
1574 | bool normalize_by_lengths, |
1575 | float* out, |
1576 | bool is_weight_positional, |
1577 | bool use_offsets) { |
1578 | bool is8bit = is_same<InType, uint8_t>::value; |
1579 | |
1580 | if (is8bit) { |
1581 | // block_size is the number of elements and fused_block_size is the size |
1582 | // of an entire row, including scale and bias. |
1583 | const auto scale_bias_offset = 2 * sizeof(float); |
1584 | const int64_t fused_block_size = block_size + scale_bias_offset; |
1585 | int64_t current = 0; |
1586 | for (int m = 0; m < output_size; ++m) { |
1587 | memset(out, 0, sizeof(float) * block_size); |
1588 | int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m] |
1589 | : offsets_or_lengths[m]; |
1590 | if (current + len > index_size) { |
1591 | return false; |
1592 | } |
1593 | for (int i = 0; i < len; ++i) { |
1594 | IndexType uncompressed_idx = indices[current]; |
1595 | if (uncompressed_idx < 0 || |
1596 | uncompressed_idx >= uncompressed_data_size) { |
1597 | return false; |
1598 | } |
1599 | IndexType idx = compressed_indices_table[uncompressed_idx]; |
1600 | if (idx == -1) { |
1601 | ++current; |
1602 | continue; |
1603 | } |
1604 | // if (idx < 0 || idx >= compressed_data_size) { |
1605 | // return false; |
1606 | // } |
1607 | |
1608 | const float* scale_bias = reinterpret_cast<const float*>( |
1609 | input + fused_block_size * idx + block_size); |
1610 | |
1611 | float weight = 1.0f; |
1612 | if (weights) { |
1613 | weight = weights[is_weight_positional ? i : current]; |
1614 | } |
1615 | const float scale = weight * scale_bias[0]; |
1616 | const float bias = weight * scale_bias[1]; |
1617 | |
1618 | for (int j = 0; j < block_size; ++j) { |
1619 | out[j] = |
1620 | std::fma(scale, input[fused_block_size * idx + j], out[j] + bias); |
1621 | } |
1622 | |
1623 | ++current; |
1624 | } |
1625 | if (normalize_by_lengths && len) { |
1626 | float scale = 1.f / len; |
1627 | for (int j = 0; j < block_size; ++j) { |
1628 | out[j] *= scale; |
1629 | } |
1630 | } |
1631 | out += block_size; |
1632 | } |
1633 | return current == index_size; |
1634 | } else { |
1635 | // Reference implementation of FP32 SLS |
1636 | int64_t current = 0; |
1637 | for (int m = 0; m < output_size; ++m) { |
1638 | memset(out, 0, sizeof(float) * block_size); |
1639 | int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m] |
1640 | : offsets_or_lengths[m]; |
1641 | if (current + len > index_size) { |
1642 | return false; |
1643 | } |
1644 | for (int i = 0; i < len; ++i) { |
1645 | IndexType uncompressed_idx = indices[current]; |
1646 | if (uncompressed_idx < 0 || |
1647 | uncompressed_idx >= uncompressed_data_size) { |
1648 | return false; |
1649 | } |
1650 | IndexType idx = compressed_indices_table[uncompressed_idx]; |
1651 | if (idx == -1) { |
1652 | ++current; |
1653 | continue; |
1654 | } |
1655 | // if (idx < 0 || idx >= compressed_data_size) { |
1656 | // return false; |
1657 | // } |
1658 | |
1659 | float w = 1.f; |
1660 | if (weights) { |
1661 | w = weights[is_weight_positional ? i : current]; |
1662 | } |
1663 | |
1664 | for (int j = 0; j < block_size; ++j) { |
1665 | const InType* inptr = input + block_size * idx + j; |
1666 | out[j] = std::fma( |
1667 | w, |
1668 | is_same<InType, float16>::value ? cpu_half2float(*inptr) : *inptr, |
1669 | out[j]); |
1670 | } |
1671 | |
1672 | ++current; |
1673 | } |
1674 | if (normalize_by_lengths && len) { |
1675 | float scale = 1.f / len; |
1676 | for (int j = 0; j < block_size; ++j) { |
1677 | out[j] *= scale; |
1678 | } |
1679 | } |
1680 | out += block_size; |
1681 | } |
1682 | return current == index_size; |
1683 | } |
1684 | } |
1685 | |
1686 | template <typename IndexType, typename OffsetType> |
1687 | bool EmbeddingSpMDMNBitRowWiseSparse_ref( |
1688 | int bit_rate, |
1689 | const int64_t block_size, |
1690 | const int64_t output_size, |
1691 | const int64_t index_size, |
1692 | const int64_t uncompressed_data_size, |
1693 | // const int64_t compressed_data_size, |
1694 | const uint8_t* input, |
1695 | const IndexType* indices, |
1696 | const int32_t* compressed_indices_table, |
1697 | const OffsetType* offsets_or_lengths, |
1698 | const float* weights, // optional, can be null for non-weighted sum |
1699 | bool normalize_by_lengths, |
1700 | float* out, |
1701 | bool is_weight_positional, |
1702 | bool use_offsets) { |
  assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
1704 | int num_elem_per_byte = 8 / bit_rate; |
1705 | |
1706 | // block_size is the number of elements and fused_block_size is the size of |
1707 | // an entire row, including scale and bias. |
1708 | const auto scale_bias_offset = 2 * sizeof(float16); |
1709 | const int64_t fused_block_size = |
1710 | (block_size + num_elem_per_byte - 1) / num_elem_per_byte + |
1711 | scale_bias_offset; |
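  // For example (illustrative numbers): with bit_rate == 4
  // (num_elem_per_byte == 2) and block_size == 64, a fused row occupies
  // 64 / 2 + 2 * sizeof(float16) = 32 + 4 = 36 bytes.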
1712 | int64_t current = 0; |
1713 | for (int m = 0; m < output_size; ++m) { |
1714 | memset(out, 0, sizeof(float) * block_size); |
1715 | int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m] |
1716 | : offsets_or_lengths[m]; |
1717 | if (current + len > index_size) { |
1718 | return false; |
1719 | } |
1720 | for (int i = 0; i < len; ++i, ++current) { |
1721 | IndexType uncompressed_idx = indices[current]; |
1722 | if (uncompressed_idx < 0 || uncompressed_idx >= uncompressed_data_size) { |
1723 | return false; |
1724 | } |
1725 | IndexType idx = compressed_indices_table[uncompressed_idx]; |
1726 | if (idx == -1) { |
1727 | continue; |
1728 | } |
1729 | // if (idx < 0 || idx >= compressed_data_size) { |
1730 | // return false; |
1731 | // } |
1732 | |
1733 | const float16* scale_bias = reinterpret_cast<const float16*>( |
1734 | input + fused_block_size * idx + |
1735 | (block_size + num_elem_per_byte - 1) / num_elem_per_byte); |
1736 | |
1737 | float weight = 1.0f; |
1738 | if (weights) { |
1739 | weight = weights[is_weight_positional ? i : current]; |
1740 | } |
1741 | const float scale = weight * cpu_half2float(scale_bias[0]); |
1742 | const float bias = weight * cpu_half2float(scale_bias[1]); |
1743 | |
1744 | for (int j = 0; j < block_size; ++j) { |
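        // Each byte of the row packs num_elem_per_byte quantized values:
        // select the byte holding element j, shift the target element to
        // the bottom, and mask off bit_rate bits.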
1745 | uint8_t quantized = |
1746 | input[fused_block_size * idx + j / num_elem_per_byte]; |
1747 | quantized >>= (j % num_elem_per_byte) * bit_rate; |
1748 | quantized &= (1 << bit_rate) - 1; |
1749 | |
1750 | out[j] = std::fma(scale, quantized, out[j] + bias); |
1751 | } |
1752 | } |
1753 | if (normalize_by_lengths && len) { |
1754 | float scale = 1.f / len; |
1755 | for (int j = 0; j < block_size; ++j) { |
1756 | out[j] *= scale; |
1757 | } |
1758 | } |
1759 | out += block_size; |
1760 | } |
1761 | return current == index_size; |
1762 | } |
1763 | |
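// Reference implementation of sparse AdaGrad. For each indexed row and
// each element j, with weight decay scaled by the counter-based
// frequency adjustment `freq`:
//   gj   = g[j] + weight_decay * freq * w[j]
//   h[j] = h[j] + gj * gj
//   w[j] = w[j] + lr * gj / (sqrt(h[j]) + epsilon)
// Returns num_rows on success, or the position of the first row whose
// parameters would fall outside [0, param_size).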
1764 | template <typename IndexType> |
1765 | int sparse_adagrad_ref( |
    int num_rows, // number of rows to read
    int block_size, // number of parameters per row
    uint64_t param_size, // total number of parameters
    float* w, // parameters, updated in place
    const float* g, // input gradients
    float* h, // momentum values, updated in place
    const IndexType* indices, // indices of each row
1773 | float epsilon, |
1774 | float lr, |
1775 | float weight_decay, |
1776 | const double* counter, |
1777 | const int64_t counter_halflife) { |
1778 | for (auto i = 0; i < num_rows; ++i) { |
1779 | uint64_t idx = indices[i]; |
1780 | auto offsetI = i * block_size; |
1781 | auto offsetIdx = idx * block_size; |
1782 | |
1783 | if (block_size + offsetIdx > param_size) { |
1784 | return i; |
1785 | } |
1786 | |
1787 | float freq = |
1788 | (counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0; |
1789 | |
    const float* g_ = g + offsetI;
    const float* h_ = h + offsetIdx;
    const float* w_ = w + offsetIdx;
    // The output pointers alias the inputs: h and w are updated in place.
    float* nh_ = h + offsetIdx;
    float* nw_ = w + offsetIdx;
1801 | |
1802 | for (auto j = 0; j < block_size; ++j) { |
1803 | float gj = std::fma(weight_decay * freq, w_[j], g_[j]); |
1804 | float hj = h_[j] + gj * gj; |
1805 | nh_[j] = hj; |
1806 | nw_[j] = w_[j] + lr * gj / (std::sqrt(hj) + epsilon); |
1807 | } |
1808 | } |
1809 | return num_rows; |
1810 | } |
1811 | |
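// Rowwise variant of sparse AdaGrad: a single momentum value is kept per
// row (h is indexed by row, not by element) and is updated with the mean
// of the squared, weight-decayed gradients of that row, so all elements
// of a row share one effective step size.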
1812 | template <typename IndexType> |
1813 | int rowwise_sparse_adagrad_ref( |
    int num_rows, // number of rows to read
    int block_size, // number of parameters per row
    uint64_t param_size, // total number of parameters
    float* w, // parameters, updated in place
    const float* g, // input gradients
    float* h, // rowwise momentum (one value per row), updated in place
    const IndexType* indices, // indices of each row
1821 | float epsilon, |
1822 | float lr, |
1823 | float weight_decay, |
1824 | const double* counter, |
1825 | const int64_t counter_halflife) { |
1826 | for (auto i = 0; i < num_rows; ++i) { |
1827 | uint64_t idx = indices[i]; |
1828 | auto offsetI = i * block_size; |
1829 | auto offsetIdx = idx * block_size; |
1830 | |
1831 | if (block_size + offsetIdx > param_size) { |
1832 | return i; |
1833 | } |
1834 | |
1835 | float freq = |
1836 | (counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0; |
1837 | |
1838 | const float* g_; |
1839 | float* h_; |
1840 | float* w_; |
1841 | |
1842 | g_ = g + offsetI; |
1843 | h_ = h + idx; // This is different from sparse adagrad |
1844 | w_ = w + offsetIdx; |
1845 | |
1846 | float final_sum = 0.0f; |
    // Note: the following mirrors the horizontal-reduction order of the
    // AVX2 code fbgemm generates. That is OK for now because fbgemm always
    // uses AVX2 for SparseAdagrad: its performance is bound by memory
    // bandwidth, so AVX512 would bring no speedup. The non-vectorized
    // version would simply be:
    //   for (auto j = 0; j < block_size; ++j) {
    //     float gj = g_[j];
    //     final_sum += gj * gj;
    //   }
1855 | constexpr int VLEN = 8; |
1856 | array<float, VLEN> partial_sum = {0.0f}; |
1857 | for (auto j = 0; j < block_size; ++j) { |
1858 | float gj = std::fma(weight_decay * freq, w_[j], g_[j]); |
1859 | partial_sum[j % VLEN] += gj * gj; |
1860 | } |
1861 | final_sum = ((partial_sum[0] + partial_sum[1]) + |
1862 | (partial_sum[2] + partial_sum[3])) + |
1863 | ((partial_sum[4] + partial_sum[5]) + (partial_sum[6] + partial_sum[7])); |
1864 | final_sum /= block_size; |
1865 | float hi = *h_ = *h_ + final_sum; |
1866 | float float_step = lr / (std::sqrt(hi) + epsilon); |
1867 | |
1868 | for (auto j = 0; j < block_size; ++j) { |
1869 | float gj = std::fma(weight_decay * freq, w_[j], g_[j]); |
1870 | w_[j] += gj * float_step; |
1871 | } |
1872 | } |
1873 | return num_rows; |
1874 | } |
1875 | |
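// Rowwise sparse AdaGrad fused with the embedding-lookup traversal:
// indices are walked SLS-style through offsets_or_lengths, and for
// float16 weights the update can apply stochastic rounding, emulating
// the JIT kernels with an emu_vector_size-wide (8 or 16) software
// vector. Returns 1 on success and 0 on malformed inputs.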
1876 | template <typename DataType, typename IndexType, typename OffsetType> |
1877 | int rowwise_sparse_adagrad_fused_ref( |
1878 | int64_t block_size, |
1879 | int64_t output_size, |
1880 | int64_t index_size, |
1881 | int64_t data_size, |
1882 | DataType* w, |
1883 | const float* g, |
1884 | float* h, |
1885 | const IndexType* indices, |
1886 | const OffsetType* offsets_or_lengths, |
1887 | float epsilon, |
1888 | float lr, |
1889 | bool use_offsets, |
1890 | bool use_stochastic_rounding, |
1891 | int emu_vector_size, |
1892 | int64_t grad_stride) { |
1893 | if (grad_stride == -1) { |
1894 | grad_stride = block_size; |
1895 | } |
1896 | |
1897 | constexpr bool isFloat16w = std::is_same<float16, DataType>::value; |
  // Local random buffers emulating a SIMD vector
  // R: generated 32-bit base random numbers
  // r: 8-bit values extracted from R, used for rounding
1901 | constexpr int VLEN_MAX = 16; |
1902 | uint32_t R[VLEN_MAX], r[VLEN_MAX]; |
1903 | int vlen = emu_vector_size; |
  if (vlen != 8 && vlen != 16) {
    // Reject unsupported vector sizes up front; anything wider would
    // overflow the R/r buffers (sized for VLEN_MAX == 16).
    cerr << "Unsupported emu_vector_size: " << emu_vector_size << endl;
1907 | return 0; |
1908 | } |
1909 | |
1910 | int64_t current = 0; |
1911 | for (int m = 0; m < output_size; ++m) { |
1912 | int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m] |
1913 | : offsets_or_lengths[m]; |
1914 | if (current + len > index_size) { |
1915 | return false; |
1916 | } |
1917 | const float* g_ = g + m * grad_stride; |
    // Note: the following mirrors the horizontal-reduction order of the
    // AVX2 code fbgemm generates. That is OK for now because fbgemm always
    // uses AVX2 for SparseAdagrad: its performance is bound by memory
    // bandwidth, so AVX512 would bring no speedup. The non-vectorized
    // version would simply be:
    //   for (auto j = 0; j < block_size; ++j) {
    //     float gj = g_[j];
    //     final_sum += gj * gj;
    //   }
1926 | constexpr int VLEN_AVX2 = 8; |
1927 | array<float, VLEN_AVX2> partial_sum = {0.0f}; |
1928 | for (auto j = 0; j < block_size; ++j) { |
1929 | float gj = g_[j]; |
1930 | partial_sum[j % VLEN_AVX2] += gj * gj; |
1931 | } |
1932 | float final_sum = ((partial_sum[0] + partial_sum[1]) + |
1933 | (partial_sum[2] + partial_sum[3])) + |
1934 | ((partial_sum[4] + partial_sum[5]) + (partial_sum[6] + partial_sum[7])); |
1935 | final_sum /= block_size; |
1936 | |
1937 | for (int i = 0; i < len; ++i, ++current) { |
1938 | int64_t idx = indices[current]; |
1939 | if (idx < 0 || idx >= data_size) { |
1940 | return false; |
1941 | } |
1942 | |
1943 | float* h_ = h + idx; |
1944 | DataType* w_ = w + idx * block_size; |
1945 | |
1946 | float hi = *h_ = *h_ + final_sum; |
1947 | float float_step = lr / (std::sqrt(hi) + epsilon); |
1948 | |
1949 | int nvec = (block_size + vlen - 1) / vlen; |
1950 | int rem = (block_size % vlen) ? (block_size % vlen) : vlen; |
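      // e.g. (illustrative) block_size == 20, vlen == 8 gives nvec == 3
      // vector iterations, the last of which covers rem == 4 lanes.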
1951 | |
      // Emulate the JIT behavior of vector-width stochastic rounding.
      //
      // A fresh R buffer is generated every 4 iterations of the nvec
      // loop; each 8-bit lane of each R element (uint32_t) is consumed
      // exactly once. The extracted byte is shifted into bits [5..12]
      // and added to the FP32 weight before the FP16 conversion.
      //
      // The shifted 8-bit region
      // +-------+--------+--------+--------+
      // |       |        |   xxxxx|xxx     |
      // 31      23       15       7      0
      //
      // A half float has 10 mantissa bits while a float has 23, so the
      // low 13 mantissa bits of the FP32 value (bits [0..12]) are lost
      // when converting to FP16. Placing the random byte just below that
      // cutoff (bit 13) effectively adds a random variable uniform in
      // [0, 1) ulp before the truncating conversion below.
1967 | |
1968 | for (int n = 0; n < nvec; ++n) { |
1969 | int cur_vlen = (n == nvec - 1) ? rem : vlen; |
1970 | int sr_idx = n % 4; |
1971 | |
1972 | if (isFloat16w && use_stochastic_rounding) { |
1973 | if (sr_idx == 0) { |
1974 | for (int v = 0; v < vlen; ++v) { |
1975 | R[v] = rnd128_next(v, vlen); |
1976 | r[v] = (R[v] & 0xFFU) << 5; |
1977 | } |
1978 | } else if (sr_idx == 1) { |
1979 | for (int v = 0; v < vlen; ++v) { |
1980 | r[v] = ((R[v] & 0xFF00U) >> 8) << 5; |
1981 | } |
1982 | } else if (sr_idx == 2) { |
1983 | for (int v = 0; v < vlen; ++v) { |
1984 | r[v] = ((R[v] & 0xFF0000U) >> 16) << 5; |
1985 | } |
1986 | } else { // 3 |
1987 | for (int v = 0; v < vlen; ++v) { |
1988 | r[v] = ((R[v] & 0xFF000000U) >> 24) << 5; |
1989 | } |
1990 | } |
1991 | } |
1992 | |
1993 | for (int v = 0; v < cur_vlen; ++v) { |
1994 | int j = n * vlen + v; |
1995 | if (isFloat16w) { |
1996 | union { |
1997 | float w_f32; |
1998 | uint32_t w_i32; |
1999 | }; |
2000 | w_f32 = cpu_half2float(w_[j]); |
2001 | w_f32 = std::fma(float_step, g_[j], w_f32); |
2002 | if (use_stochastic_rounding) { |
2003 | w_i32 += r[v]; |
2004 | } |
            // Round toward zero (truncate) to compensate for the random
            // value just added; add-then-truncate implements the
            // stochastic rounding.
2006 | w_[j] = cpu_float2half_rz(w_f32); |
2007 | } else { // float |
2008 | w_[j] += g_[j] * float_step; |
2009 | } |
2010 | } |
2011 | } |
2012 | } |
2013 | } |
2014 | |
2015 | return current == index_size; |
2016 | } |
2017 | |
2018 | template FBGEMM_API void transposeConvWeights( |
2019 | const conv_param_t<1>& conv_p, |
2020 | const std::int8_t* src, |
2021 | std::int8_t* dest); |
2022 | |
2023 | template FBGEMM_API void transposeConvWeights( |
2024 | const conv_param_t<2>& conv_p, |
2025 | const std::int8_t* src, |
2026 | std::int8_t* dest); |
2027 | |
2028 | template FBGEMM_API void transposeConvWeights( |
2029 | const conv_param_t<3>& conv_p, |
2030 | const std::int8_t* src, |
2031 | std::int8_t* dest); |
2032 | |
2033 | #define INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ |
2034 | template FBGEMM_API bool EmbeddingSpMDM_ref( \ |
2035 | const int64_t block_size, \ |
2036 | const int64_t output_size, \ |
2037 | const int64_t index_size, \ |
2038 | const int64_t data_size, \ |
2039 | const IN_TYPE* input, \ |
2040 | const INDEX_TYPE* indices, \ |
2041 | const OFFSET_TYPE* offsets_or_lengths, \ |
2042 | const float* weights, \ |
2043 | bool normalize_by_lengths, \ |
2044 | OUT_TYPE* out, \ |
2045 | bool is_weight_positional, \ |
2046 | bool use_offsets, \ |
2047 | int64_t input_stride, \ |
2048 | int64_t output_stride, \ |
2049 | bool scale_bias_last, \ |
2050 | bool no_bag, \ |
2051 | bool is_bf16); |
2052 | |
2053 | #define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \ |
2054 | INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \ |
2055 | INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float16) \ |
2056 | template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( \ |
2057 | const int64_t block_size, \ |
2058 | const int64_t output_size, \ |
2059 | const int64_t index_size, \ |
2060 | const int64_t uncompressed_data_size, \ |
2061 | const IN_TYPE* input, \ |
2062 | const INDEX_TYPE* indices, \ |
2063 | const int32_t* compressed_indices_table, \ |
2064 | const OFFSET_TYPE* offsets_or_lengths, \ |
2065 | const float* weights, \ |
2066 | bool normalize_by_lengths, \ |
2067 | float* out, \ |
2068 | bool is_weight_positional, \ |
2069 | bool use_offsets); |
2070 | |
2071 | #define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE) \ |
2072 | INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, std::int32_t) \ |
2073 | INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, std::int64_t) |
2074 | |
2075 | #define INSTANTIATE_SPMDM_INDEX_T(IN_TYPE) \ |
2076 | INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int32_t) \ |
2077 | INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int64_t) |
2078 | |
2079 | INSTANTIATE_SPMDM_INDEX_T(float) |
2080 | INSTANTIATE_SPMDM_INDEX_T(float16) |
2081 | INSTANTIATE_SPMDM_INDEX_T(std::uint8_t) |
2082 | |
2083 | #undef INSTANTIATE_SPMDM_INDEX_T |
2084 | #undef INSTANTIATE_SPMDM_OFFSET_T |
2085 | #undef INSTANTIATE_SPMDM_OUT_T |
2086 | #undef INSTANTIATE_SPMDM_BASE |
2087 | |
2088 | #define INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ |
2089 | template FBGEMM_API bool EmbeddingSpMDMNBit_ref( \ |
2090 | int bit_rate, \ |
2091 | const int64_t block_size, \ |
2092 | const int64_t output_size, \ |
2093 | const int64_t index_size, \ |
2094 | const int64_t data_size, \ |
2095 | const uint8_t* input, \ |
2096 | const INDEX_TYPE* indices, \ |
2097 | const OFFSET_TYPE* offsets_or_lengths, \ |
2098 | const float* weights, \ |
2099 | bool normalize_by_lengths, \ |
2100 | OUT_TYPE* out, \ |
2101 | bool is_weight_positional, \ |
2102 | bool use_offsets, \ |
2103 | int64_t output_stride, \ |
2104 | int64_t input_stride, \ |
2105 | bool scale_bias_last); \ |
2106 | template FBGEMM_API bool EmbeddingSpMDMFP8_ref( \ |
2107 | const int64_t block_size, \ |
2108 | const int64_t output_size, \ |
2109 | const int64_t index_size, \ |
2110 | const int64_t data_size, \ |
2111 | const uint8_t* input, \ |
2112 | const INDEX_TYPE* indices, \ |
2113 | const OFFSET_TYPE* offsets_or_lengths, \ |
2114 | const float* weights, \ |
2115 | bool normalize_by_lengths, \ |
2116 | OUT_TYPE* out, \ |
2117 | bool is_weight_positional, \ |
2118 | bool use_offsets, \ |
2119 | int64_t output_stride, \ |
2120 | int64_t input_stride, \ |
2121 | int exponent_bits, \ |
2122 | int exponent_bias); |
2123 | |
2124 | #define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ |
2125 | INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ |
2126 | INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \ |
2127 | template FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( \ |
2128 | int bit_rate, \ |
2129 | const int64_t block_size, \ |
2130 | const int64_t output_size, \ |
2131 | const int64_t index_size, \ |
2132 | const int64_t uncompressed_data_size, \ |
2133 | const uint8_t* input, \ |
2134 | const INDEX_TYPE* indices, \ |
2135 | const int32_t* compressed_indices_table, \ |
2136 | const OFFSET_TYPE* offsets_or_lengths, \ |
2137 | const float* weights, \ |
2138 | bool normalize_by_lengths, \ |
2139 | float* out, \ |
2140 | bool is_weight_positional, \ |
2141 | bool use_offsets); |
2142 | |
2143 | #define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \ |
2144 | INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int32_t) \ |
2145 | INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int64_t) |
2146 | |
2147 | INSTANTIATE_SPMDM_OFFSET_T(int32_t) |
2148 | INSTANTIATE_SPMDM_OFFSET_T(int64_t) |
2149 | |
2150 | #undef INSTANTIATE_SPMDM_OFFSET_T |
2151 | #undef INSTANTIATE_SPMDM_OUT_T |
2152 | #undef INSTANTIATE_SPMDM_BASE |
2153 | |
2154 | template FBGEMM_API int sparse_adagrad_ref( |
    int num_rows, // number of rows to read
    int block_size, // number of parameters per row
    std::uint64_t param_size, // total number of parameters
    float* w, // parameters, updated in place
    const float* g, // input gradients
    float* h, // momentum values, updated in place
    const std::int64_t* indices, // indices of each row
2162 | float epsilon, |
2163 | float lr, |
2164 | float weight_decay, |
2165 | const double* counter, |
2166 | const int64_t counter_halflife); |
2167 | |
2168 | template FBGEMM_API int sparse_adagrad_ref( |
    int num_rows, // number of rows to read
    int block_size, // number of parameters per row
    std::uint64_t param_size, // total number of parameters
    float* w, // parameters, updated in place
    const float* g, // input gradients
    float* h, // momentum values, updated in place
    const std::int32_t* indices, // indices of each row
2176 | float epsilon, |
2177 | float lr, |
2178 | float weight_decay, |
2179 | const double* counter, |
2180 | const int64_t counter_halflife); |
2181 | |
2182 | template FBGEMM_API int rowwise_sparse_adagrad_ref( |
    int num_rows, // number of rows to read
    int block_size, // number of parameters per row
    std::uint64_t param_size, // total number of parameters
    float* w, // parameters, updated in place
    const float* g, // input gradients
    float* h, // rowwise momentum (one value per row), updated in place
    const std::int64_t* indices, // indices of each row
2190 | float epsilon, |
2191 | float lr, |
2192 | float weight_decay, |
2193 | const double* counter, |
2194 | const int64_t counter_halflife); |
2195 | |
2196 | template FBGEMM_API int rowwise_sparse_adagrad_ref( |
    int num_rows, // number of rows to read
    int block_size, // number of parameters per row
    std::uint64_t param_size, // total number of parameters
    float* w, // parameters, updated in place
    const float* g, // input gradients
    float* h, // rowwise momentum (one value per row), updated in place
    const std::int32_t* indices, // indices of each row
2204 | float epsilon, |
2205 | float lr, |
2206 | float weight_decay, |
2207 | const double* counter, |
2208 | const int64_t counter_halflife); |
2209 | |
2210 | #define INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, OFFSET_TYPE) \ |
2211 | template FBGEMM_API int rowwise_sparse_adagrad_fused_ref( \ |
2212 | int64_t block_size, \ |
2213 | int64_t output_size, \ |
2214 | int64_t index_size, \ |
2215 | int64_t data_size, \ |
2216 | DATA_TYPE* w, \ |
2217 | const float* g, \ |
2218 | float* h, \ |
2219 | const INDEX_TYPE* indices, \ |
2220 | const OFFSET_TYPE* offsets_or_lengths, \ |
2221 | float epsilon, \ |
2222 | float lr, \ |
2223 | bool use_offsets, \ |
2224 | bool use_stochastic_rounding, \ |
2225 | int emu_vector_size, \ |
2226 | int64_t grad_stride); |
2227 | |
2228 | #define INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, INDEX_TYPE) \ |
2229 | INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int32_t) \ |
2230 | INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int64_t) |
2231 | |
2232 | #define INSTANTIATE_SPMDM_INDEX_T(DATA_TYPE) \ |
2233 | INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, int32_t) \ |
2234 | INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, int64_t) |
2235 | |
2236 | INSTANTIATE_SPMDM_INDEX_T(float) |
2237 | INSTANTIATE_SPMDM_INDEX_T(float16) |
2238 | |
#undef INSTANTIATE_SPMDM_INDEX_T
#undef INSTANTIATE_SPMDM_OFFSET_T
#undef INSTANTIATE_SPMDM_BASE
2241 | |
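// Remaps an (indices, offsets, weights) triple through
// compressed_indices_mapping, dropping every index whose mapping is -1
// (a pruned row) and rebuilding the offsets to match. For example
// (hypothetical values): indices = {5, 3, 9}, offsets = {0, 2, 3},
// mapping[5] = -1, mapping[3] = 1, mapping[9] = 2 yields
// out_indices = {1, 2} and out_offsets = {0, 1, 2}.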
2242 | template <typename IndexType> |
2243 | FBGEMM_API void compressed_indices_remap_ref( |
2244 | std::int32_t offsets_numel, |
2245 | const IndexType* indices, |
2246 | const int32_t* compressed_indices_mapping, |
2247 | const IndexType* offsets, |
    const float* weights, // optional, can be null
2249 | IndexType* out_indices, |
2250 | IndexType* out_offsets, |
2251 | float* out_weights) { |
2252 | bool has_per_sample_weights = (weights != nullptr); |
2253 | out_offsets[0] = offsets[0]; |
2254 | IndexType j = 0; |
2255 | for (int i = 1; i < offsets_numel; i++) { |
2256 | for (int32_t k = offsets[i - 1]; k < offsets[i]; k++) { |
2257 | if (compressed_indices_mapping[indices[k]] != -1) { |
2258 | out_indices[j] = compressed_indices_mapping[indices[k]]; |
2259 | if (has_per_sample_weights) { |
2260 | out_weights[j] = weights[k]; |
2261 | } |
2262 | j++; |
2263 | } |
2264 | } |
2265 | out_offsets[i] = j; |
2266 | } |
2267 | } |
2268 | |
2269 | #define INSTANTIATE_REMAP_BASE(INDEX_TYPE) \ |
2270 | template FBGEMM_API void compressed_indices_remap_ref( \ |
2271 | std::int32_t offsets_numel, \ |
2272 | const INDEX_TYPE* indices, \ |
2273 | const int32_t* compressed_indices_mapping, \ |
2274 | const INDEX_TYPE* offsets, \ |
2275 | const float* weights, \ |
2276 | INDEX_TYPE* out_indices, \ |
2277 | INDEX_TYPE* out_offsets, \ |
2278 | float* out_weights); |
2279 | |
2280 | INSTANTIATE_REMAP_BASE(int32_t) |
2281 | INSTANTIATE_REMAP_BASE(int64_t) |
2282 | |
2283 | #undef INSTANTIATE_REMAP_BASE |
2284 | |
2285 | } // namespace fbgemm |
2286 | |