1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | |
#include <algorithm>
#include <cstdint>
#include <type_traits>

#include "fbgemm/ConvUtils.h"
#include "fbgemm/FbgemmI8Spmdm.h"
#include "fbgemm/Types.h"
15 | |
16 | namespace fbgemm { |
17 | |
/**
 * @brief Reference implementation of the requantization step using an int32
 * multiplier (C_multiplier) together with a right shift (C_right_shift).
 * @param bias can be nullptr
 */
23 | FBGEMM_API void requantize_u8acc32_ref( |
24 | int M, |
25 | int N, |
26 | int ld, |
27 | const std::int32_t* inp, |
28 | std::uint8_t* out, |
29 | std::int32_t C_multiplier, |
30 | std::int32_t C_right_shift, |
31 | std::int32_t C_zero_point, |
32 | std::int32_t A_zero_point, |
33 | std::int32_t B_zero_point, |
34 | const std::int32_t* row_offsets, |
35 | const std::int32_t* col_offsets, |
36 | const std::int32_t* bias, |
37 | bool fuse_relu = false); |
38 | |
/**
 * @brief Reference implementation of the requantization step using float
 * multipliers.
 * @param bias can be nullptr
 * @param ncols_per_quant_group the number of columns that share the same
 *        quantization parameters.
 *        ncols_per_quant_group == N : per-tensor quantization
 *        ncols_per_quant_group == N / groups : per-group quantization
 *        ncols_per_quant_group == 1 : per-channel quantization
 */
49 | FBGEMM_API void requantize_u8acc32_ref( |
50 | int M, |
51 | int N, |
52 | int ld, |
53 | const std::int32_t* inp, |
54 | std::uint8_t* out, |
55 | const float* C_multiplier, |
56 | std::int32_t C_zero_point, |
57 | std::int32_t A_zero_point, |
58 | const std::int32_t* B_zero_point, |
59 | const std::int32_t* row_offsets, |
60 | const std::int32_t* col_offsets, |
61 | const std::int32_t* bias, |
62 | int ncols_per_quant_group, |
63 | bool fuse_relu = false); |
64 | |
65 | /** |
66 | * @brief Reference implementation of matrix multiply with uint8 for A, |
67 | * int8 for B, and 32-bit accumulation. |
68 | */ |
69 | FBGEMM_API void matmul_u8i8acc32_ref( |
70 | int M, |
71 | int N, |
72 | int K, |
73 | int lda, |
74 | int ldb, |
75 | int ldc, |
76 | const std::uint8_t* Aint8, |
77 | const std::int8_t* Bint8, |
78 | std::int32_t* Cint32); |
79 | |
/**
 * @brief Reference implementation of matrix multiply with uint8 for A,
 * int8 for B, and 16-bit accumulation.
 */
84 | FBGEMM_API void matmul_u8i8acc16_ref( |
85 | int M, |
86 | int N, |
87 | int K, |
88 | int lda, |
89 | int ldb, |
90 | int ldc, |
91 | int brow, |
92 | const std::uint8_t* Aint8, |
93 | const std::int8_t* Bint8, |
94 | std::int32_t* Cint32); |
95 | |
96 | /** |
97 | * @brief Reference implementation of cblas_sgemm in MKL/BLAS. |
98 | */ |
99 | FBGEMM_API void cblas_sgemm_ref( |
100 | const matrix_op_t transa, |
101 | const matrix_op_t transb, |
102 | const int m, |
103 | const int n, |
104 | const int k, |
105 | float alpha, |
106 | const float* Afp32, |
107 | int lda, |
108 | const float* Bfp32, |
109 | int ldb, |
110 | float beta, |
111 | float* Cfp32, |
112 | int ldc); |
113 | |
/**
 * @brief Declaration of a reference GEMM whose inputs A and B and output C
 * are all std::int64_t matrices.
 *
 * NOTE(review): the name suggests 64-bit accumulation, and `accumulate`
 * presumably selects between C += op(A) * op(B) and C = op(A) * op(B) --
 * confirm against the definition, which is not part of this header.
 */
FBGEMM_API void cblas_gemm_i64_i64acc_ref(
    matrix_op_t transa,
    matrix_op_t transb,
    int M,
    int N,
    int K,
    const std::int64_t* A,
    int lda,
    const std::int64_t* B,
    int ldb,
    bool accumulate,
    std::int64_t* C,
    int ldc);
127 | |
128 | /** |
129 | * @brief Reference implementation to compute row_offsets (sums of rows of A). |
130 | */ |
131 | FBGEMM_API void row_offsets_u8acc32_ref( |
132 | int M, |
133 | int K, |
134 | int ld, |
135 | const std::uint8_t* Aint8, |
136 | std::int32_t* row_offsets); |
137 | |
138 | /** |
139 | * @brief Reference implementation to compute adjusted col_offsets (sum of |
140 | * columns of B and adjusted with B_zero_point) |
141 | * |
142 | * @param ncols_per_quant_group see ncols_per_quant_group in |
143 | * requantize_u8acc32_ref |
144 | */ |
145 | FBGEMM_API void col_offsets_with_zero_pt_s8acc32_ref( |
146 | int K, |
147 | int N, |
148 | int ld, |
149 | const std::int8_t* Bint8, |
150 | const std::int32_t* B_zero_point, |
151 | std::int32_t* col_offsets, |
152 | int ncols_per_quant_group); |
153 | |
154 | /** |
155 | * @brief Reference implementation of SPMDM (sparse matrix times dense matrix). |
156 | * |
157 | * @param groups when > 1, for gth group, we multiply |
158 | * A[:,g*(A.ncols/groups):(g+1)*(A.ncols/groups)] sub-matrix with |
159 | * B[:,g*(B.ncols/groups):(g+1)*(B.ncols/groups)] sub-matrix . |
160 | */ |
161 | FBGEMM_API void spmdm_ref( |
162 | int M, |
163 | const std::uint8_t* A, |
164 | int lda, |
165 | CompressedSparseColumn& B, |
166 | bool accumulation, |
167 | std::int32_t* C, |
168 | int ldc, |
169 | int groups = 1); |
170 | |
/**
 * @brief Trim a 32-bit integer to a 16-bit integer.
 */
174 | int32_t clip_16bit(int32_t x); |
175 | |
/**
 * @brief Reference implementation of convolution operation.
 * The activations A are assumed to be in NHiWiC format.
 * The filters B are assumed to be in RSCK format.
 * The output C is assumed to be in NHoWoC format.
 */
182 | template <int SPATIAL_DIM = 2> |
183 | FBGEMM_API void conv_ref( |
184 | const conv_param_t<SPATIAL_DIM>& conv_p, |
185 | const std::uint8_t* A, |
186 | std::int32_t A_zero_point, |
187 | const std::int8_t* B, |
188 | std::int32_t* C); |
189 | |
/**
 * @brief Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
 */
193 | template <int SPATIAL_DIM = 2> |
194 | FBGEMM_API void transposeConvWeights( |
195 | const conv_param_t<SPATIAL_DIM>& conv_p, |
196 | const std::int8_t* src, |
197 | std::int8_t* dest); |
198 | |
/**
 * @brief Reference implementation of im2col operation.
 *
 * For 2D:
 * The input A is assumed to be in NHiWiC format.
 * The output Ao is assumed to be in NHoWoRSC format.
 *
 * For 3D:
 * The input A is assumed to be in NTiHiWiC format.
 * The output Ao is assumed to be in NToHoWoK0K1K2C format.
 */
210 | template <int SPATIAL_DIM = 2> |
211 | FBGEMM_API void im2col_ref( |
212 | const conv_param_t<SPATIAL_DIM>& conv_p, |
213 | const std::uint8_t* A, |
214 | std::int32_t A_zero_point, |
215 | std::uint8_t* Ao); |
216 | |
217 | template < |
218 | typename InType = std::uint8_t, |
219 | typename IndexType = std::int64_t, |
220 | typename OffsetType = std::int32_t, |
221 | typename OutType = float> |
222 | FBGEMM_API bool EmbeddingSpMDM_ref( |
223 | const std::int64_t block_size, |
224 | const std::int64_t output_size, |
225 | const std::int64_t index_size, |
226 | const std::int64_t data_size, |
227 | const InType* input, |
228 | const IndexType* indices, |
229 | const OffsetType* offsets_or_lengths, |
230 | const float* weights, // optional, can be null for non-weighted sum |
231 | bool normalize_by_lengths, |
232 | OutType* out, |
233 | bool is_weight_positional = false, |
234 | bool use_offsets = true, |
235 | std::int64_t output_stride = -1, |
236 | std::int64_t input_stride = -1, |
237 | bool scale_bias_last = true, |
238 | bool no_bag = false, |
239 | bool is_bf16 = false); |
240 | |
241 | template < |
242 | typename IndexType = std::int64_t, |
243 | typename OffsetType = std::int32_t, |
244 | typename OutType = float> |
245 | FBGEMM_API bool EmbeddingSpMDMNBit_ref( |
246 | int bit_rate, |
247 | const std::int64_t block_size, |
248 | const std::int64_t output_size, |
249 | const std::int64_t index_size, |
250 | const std::int64_t data_size, |
251 | const std::uint8_t* input, |
252 | const IndexType* indices, |
253 | const OffsetType* offsets_or_lengths, |
254 | const float* weights, // optional, can be null for non-weighted sum |
255 | bool normalize_by_lengths, |
256 | OutType* out, |
257 | bool is_weight_positional = false, |
258 | bool use_offsets = true, |
259 | std::int64_t output_stride = -1, |
260 | std::int64_t input_stride = -1, |
261 | bool scale_bias_last = true); |
262 | |
/**
 * @brief Declaration of the reference EmbeddingSpMDM variant that, per its
 * name, handles FP8 input; `input` is passed as raw uint8_t bytes.
 *
 * exponent_bits and exponent_bias default to 4 and 7.
 * NOTE(review): they presumably describe the FP8 encoding of `input` --
 * confirm against the definition.
 * NOTE(review): unlike the sibling EmbeddingSpMDM*_ref declarations in this
 * header, this one is not marked FBGEMM_API -- confirm that is intentional.
 */
template <
    typename IndexType = std::int64_t,
    typename OffsetType = std::int32_t,
    typename OutType = float>
bool EmbeddingSpMDMFP8_ref(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights,
    bool normalize_by_lengths,
    OutType* out,
    bool is_weight_positional = false,
    bool use_offsets = true,
    int64_t output_stride = -1,
    int64_t input_stride = -1,
    int exponent_bits = 4,
    int exponent_bias = 7);
284 | |
285 | template < |
286 | typename InType = std::uint8_t, |
287 | typename IndexType = std::int64_t, |
288 | typename OffsetType = std::int32_t> |
289 | FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( |
290 | const std::int64_t block_size, |
291 | const std::int64_t output_size, |
292 | const std::int64_t index_size, |
293 | const std::int64_t uncompressed_data_size, |
294 | // const std::int64_t compressed_data_size, |
295 | const InType* input, |
296 | const IndexType* indices, |
297 | const std::int32_t* compressed_indices_table, |
298 | const OffsetType* offsets_or_lengths, |
299 | const float* weights, // optional, can be null for non-weighted sum |
300 | bool normalize_by_lengths, |
301 | float* out, |
302 | bool is_weight_positional = false, |
303 | bool use_offsets = true); |
304 | |
305 | template <typename IndexType = std::int64_t, typename OffsetType = std::int32_t> |
306 | FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( |
307 | int bit_rate, |
308 | const std::int64_t block_size, |
309 | const std::int64_t output_size, |
310 | const std::int64_t index_size, |
311 | const std::int64_t uncompressed_data_size, |
312 | // const std::int64_t compressed_data_size, |
313 | const std::uint8_t* input, |
314 | const IndexType* indices, |
315 | const std::int32_t* compressed_indices_table, |
316 | const OffsetType* offsets_or_lengths, |
317 | const float* weights, // optional, can be null for non-weighted sum |
318 | bool normalize_by_lengths, |
319 | float* out, |
320 | bool is_weight_positional = false, |
321 | bool use_offsets = true); |
322 | |
/**
 * @param num_rows number of rows to read
 * @param block_size number of parameters per row
 * @param param_size total number of parameters
 * @param w input parameters
 * @param g input gradients
 * @param h input momentum
 * @param indices indices of each row
 * @param counter used for weight_decay adjusted for frequency. nullptr when
 *                frequency adjustment is not used. Ignored when weight_decay
 *                == 0
 * @param counter_halflife weight_decay is adjusted only after this number of
 *                         iterations
 */
337 | template <typename IndexType> |
338 | FBGEMM_API int sparse_adagrad_ref( |
339 | int num_rows, |
340 | int block_size, |
341 | std::uint64_t param_size, |
342 | float* w, |
343 | const float* g, |
344 | float* h, |
345 | const IndexType* indices, |
346 | float epsilon, |
347 | float lr, |
348 | float weight_decay = 0.f, |
349 | const double* counter = nullptr, |
350 | const int64_t counter_halflife = 0); |
351 | |
/**
 * @param num_rows number of rows to read
 * @param block_size number of parameters per row
 * @param param_size total number of parameters
 * @param w input parameters
 * @param g input gradients
 * @param h input momentum
 * @param indices indices of each row
 * @param counter used for weight_decay adjusted for frequency. nullptr when
 *                frequency adjustment is not used. Ignored when weight_decay
 *                == 0
 * @param counter_halflife weight_decay is adjusted only after this number of
 *                         iterations
 */
366 | template <typename IndexType> |
367 | FBGEMM_API int rowwise_sparse_adagrad_ref( |
368 | int num_rows, |
369 | int block_size, |
370 | std::uint64_t param_size, |
371 | float* w, |
372 | const float* g, |
373 | float* h, |
374 | const IndexType* indices, |
375 | float epsilon, |
376 | float lr, |
377 | float weight_decay = 0.f, |
378 | const double* counter = nullptr, |
379 | const int64_t counter_halflife = 0); |
380 | |
/**
 * @brief Declaration of the reference "fused" rowwise sparse Adagrad routine.
 *
 * In addition to the parameters, gradients and momentums it takes indices
 * plus an offsets_or_lengths array, interpreted according to use_offsets.
 * NOTE(review): what exactly is fused, and the meaning of emu_vector_size and
 * of grad_stride == -1, are defined by the implementation rather than this
 * header -- confirm there.
 */
template <typename DataType, typename IndexType, typename OffsetType>
FBGEMM_API int rowwise_sparse_adagrad_fused_ref(
    std::int64_t block_size,
    std::int64_t output_size,
    std::int64_t index_size,
    std::int64_t data_size,
    DataType* w, // input/output parameters
    const float* g, // input gradients
    float* h, // input/output momentums
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    float epsilon,
    float lr,
    bool use_offsets = true,
    bool use_stochastic_rounding = true, // For DataType=float16
    int emu_vector_size = 8,
    std::int64_t grad_stride = -1);
398 | |
/**
 * @brief Declaration of the reference routine that remaps indices through
 * compressed_indices_mapping and writes out_indices, out_offsets and
 * out_weights.
 *
 * NOTE(review): how unmapped entries are treated, and whether out_weights is
 * touched when weights is null, is decided by the definition -- confirm there.
 */
template <typename IndexType>
FBGEMM_API void compressed_indices_remap_ref(
    std::int32_t offsets_len,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const IndexType* offsets,
    const float* weights, // optional, can be null
    IndexType* out_indices,
    IndexType* out_offsets,
    float* out_weights);
409 | |
/**
 * @brief Converts a value of type T to float.
 *
 * When T is std::uint16_t the value is handed to cpu_bf162float
 * (is_bf16 == true) or cpu_half2float (is_bf16 == false). For every other T
 * the value is converted with an ordinary conversion to float and is_bf16 is
 * ignored.
 *
 * @param src value to convert
 * @param is_bf16 selects cpu_bf162float over cpu_half2float; only consulted
 *                when T is std::uint16_t
 * @return the converted float
 */
template <typename T>
float convert_to_float_ref(T src, bool is_bf16 = false) {
  // This is a plain runtime branch on a compile-time constant: both arms are
  // instantiated for every T, so the helper calls below must stay well-formed
  // even for types that never reach them.
  if (!std::is_same<T, std::uint16_t>::value) {
    return static_cast<float>(src);
  }
  return is_bf16 ? cpu_bf162float(src) : cpu_half2float(src);
}
420 | |
421 | template <typename T> |
422 | T convert_from_float_ref(float src, bool is_bf16 = false) { |
423 | T o_value; |
424 | if (std::is_same<T, uint16_t>::value) { |
425 | o_value = is_bf16 ? cpu_float2bfloat16(src) : cpu_float2half_rn(src); |
426 | } else { |
427 | o_value = src; |
428 | } |
429 | return o_value; |
430 | } |
431 | |
432 | } // namespace fbgemm |
433 | |