/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <algorithm>
#include <cstdint>

#include "fbgemm/ConvUtils.h"
#include "fbgemm/FbgemmI8Spmdm.h"
#include "fbgemm/Types.h"

namespace fbgemm {

/**
 * @brief Reference implementation of the requantization step with an int32
 * multiplier and right shift (fixed-point requantization).
 * @param bias can be nullptr
 */
FBGEMM_API void requantize_u8acc32_ref(
    int M,
    int N,
    int ld,
    const std::int32_t* inp,
    std::uint8_t* out,
    std::int32_t C_multiplier,
    std::int32_t C_right_shift,
    std::int32_t C_zero_point,
    std::int32_t A_zero_point,
    std::int32_t B_zero_point,
    const std::int32_t* row_offsets,
    const std::int32_t* col_offsets,
    const std::int32_t* bias,
    bool fuse_relu = false);
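
// Usage sketch (illustrative only): requantize the int32 accumulators of a
// uint8 x int8 GEMM into uint8. The buffers and quantization parameters below
// are assumptions of this example, not part of the API.
//
//   std::vector<std::int32_t> Cint32(M * N); // raw GEMM accumulators
//   std::vector<std::uint8_t> Cuint8(M * N); // requantized output
//   requantize_u8acc32_ref(
//       M, N, /*ld=*/N, Cint32.data(), Cuint8.data(),
//       C_multiplier, C_right_shift, C_zero_point,
//       A_zero_point, B_zero_point,
//       row_offsets.data(), col_offsets.data(),
//       /*bias=*/nullptr, /*fuse_relu=*/true);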

/**
 * @brief Reference implementation of the requantization step with a float
 * multiplier.
 * @param bias can be nullptr
 * @param ncols_per_quant_group the number of columns that share the same
 *        quantization parameters.
 *        ncols_per_quant_group == N : per-tensor quantization
 *        ncols_per_quant_group == N / groups : per-group quantization
 *        ncols_per_quant_group == 1 : per-channel quantization
 */
FBGEMM_API void requantize_u8acc32_ref(
    int M,
    int N,
    int ld,
    const std::int32_t* inp,
    std::uint8_t* out,
    const float* C_multiplier,
    std::int32_t C_zero_point,
    std::int32_t A_zero_point,
    const std::int32_t* B_zero_point,
    const std::int32_t* row_offsets,
    const std::int32_t* col_offsets,
    const std::int32_t* bias,
    int ncols_per_quant_group,
    bool fuse_relu = false);
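
// Per-channel usage sketch (illustrative; the buffers are assumptions): with
// ncols_per_quant_group == 1, C_multiplier and B_zero_point hold one entry
// per output column.
//
//   std::vector<float> C_multiplier(N);
//   std::vector<std::int32_t> B_zero_point(N);
//   requantize_u8acc32_ref(
//       M, N, /*ld=*/N, Cint32.data(), Cuint8.data(),
//       C_multiplier.data(), C_zero_point, A_zero_point,
//       B_zero_point.data(), row_offsets.data(), col_offsets.data(),
//       /*bias=*/nullptr, /*ncols_per_quant_group=*/1);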

/**
 * @brief Reference implementation of matrix multiply with uint8 for A,
 * int8 for B, and 32-bit accumulation.
 */
FBGEMM_API void matmul_u8i8acc32_ref(
    int M,
    int N,
    int K,
    int lda,
    int ldb,
    int ldc,
    const std::uint8_t* Aint8,
    const std::int8_t* Bint8,
    std::int32_t* Cint32);
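
// Usage sketch (illustrative; row-major buffers with lda = ldb = ldc equal to
// the matrix widths are assumptions of this example):
//
//   std::vector<std::uint8_t> Aint8(M * K);
//   std::vector<std::int8_t> Bint8(K * N);
//   std::vector<std::int32_t> Cint32(M * N);
//   matmul_u8i8acc32_ref(
//       M, N, K, /*lda=*/K, /*ldb=*/N, /*ldc=*/N,
//       Aint8.data(), Bint8.data(), Cint32.data());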

/**
 * @brief Reference implementation of matrix multiply with uint8 for A,
 * int8 for B, and 16-bit accumulation.
 */
FBGEMM_API void matmul_u8i8acc16_ref(
    int M,
    int N,
    int K,
    int lda,
    int ldb,
    int ldc,
    int brow,
    const std::uint8_t* Aint8,
    const std::int8_t* Bint8,
    std::int32_t* Cint32);

/**
 * @brief Reference implementation of cblas_sgemm in MKL/BLAS.
 */
FBGEMM_API void cblas_sgemm_ref(
    const matrix_op_t transa,
    const matrix_op_t transb,
    const int m,
    const int n,
    const int k,
    float alpha,
    const float* Afp32,
    int lda,
    const float* Bfp32,
    int ldb,
    float beta,
    float* Cfp32,
    int ldc);
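
// Usage sketch (illustrative): C = alpha * A * B + beta * C, no transposition,
// row-major buffers with leading dimensions equal to the row widths.
//
//   cblas_sgemm_ref(
//       matrix_op_t::NoTranspose, matrix_op_t::NoTranspose,
//       m, n, k, /*alpha=*/1.0f, Afp32.data(), /*lda=*/k,
//       Bfp32.data(), /*ldb=*/n, /*beta=*/0.0f, Cfp32.data(), /*ldc=*/n);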

FBGEMM_API void cblas_gemm_i64_i64acc_ref(
    matrix_op_t transa,
    matrix_op_t transb,
    int M,
    int N,
    int K,
    const std::int64_t* A,
    int lda,
    const std::int64_t* B,
    int ldb,
    bool accumulate,
    std::int64_t* C,
    int ldc);

/**
 * @brief Reference implementation to compute row_offsets (sums of rows of A).
 */
FBGEMM_API void row_offsets_u8acc32_ref(
    int M,
    int K,
    int ld,
    const std::uint8_t* Aint8,
    std::int32_t* row_offsets);

/**
 * @brief Reference implementation to compute adjusted col_offsets (sums of
 * the columns of B, adjusted by B_zero_point).
 *
 * @param ncols_per_quant_group see ncols_per_quant_group in
 * requantize_u8acc32_ref
 */
FBGEMM_API void col_offsets_with_zero_pt_s8acc32_ref(
    int K,
    int N,
    int ld,
    const std::int8_t* Bint8,
    const std::int32_t* B_zero_point,
    std::int32_t* col_offsets,
    int ncols_per_quant_group);
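
// Usage sketch (illustrative): compute the offsets consumed by
// requantize_u8acc32_ref, here with per-tensor quantization
// (ncols_per_quant_group == N and a single B_zero_point value).
//
//   std::vector<std::int32_t> row_offsets(M), col_offsets(N);
//   row_offsets_u8acc32_ref(M, K, /*ld=*/K, Aint8.data(), row_offsets.data());
//   std::int32_t B_zero_point = 2; // assumed value
//   col_offsets_with_zero_pt_s8acc32_ref(
//       K, N, /*ld=*/N, Bint8.data(), &B_zero_point, col_offsets.data(),
//       /*ncols_per_quant_group=*/N);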

/**
 * @brief Reference implementation of SPMDM (sparse matrix times dense matrix).
 *
 * @param groups when groups > 1, for the g-th group we multiply the
 *               A[:,g*(A.ncols/groups):(g+1)*(A.ncols/groups)] sub-matrix
 *               with the B[:,g*(B.ncols/groups):(g+1)*(B.ncols/groups)]
 *               sub-matrix.
 */
FBGEMM_API void spmdm_ref(
    int M,
    const std::uint8_t* A,
    int lda,
    CompressedSparseColumn& B,
    bool accumulation,
    std::int32_t* C,
    int ldc,
    int groups = 1);

/*
 * @brief Clip a 32-bit integer to the range of a 16-bit integer.
 */
int32_t clip_16bit(int32_t x);

/*
 * @brief Reference implementation of convolution operation.
 * The activations A are assumed to be in NHiWiC format.
 * The filters B are assumed to be in RSCK format.
 * The output C is assumed to be in NHoWoC format.
 */
template <int SPATIAL_DIM = 2>
FBGEMM_API void conv_ref(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const std::uint8_t* A,
    std::int32_t A_zero_point,
    const std::int8_t* B,
    std::int32_t* C);
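
// Usage sketch (illustrative; the conv_param_t constructor arguments and the
// MB/OC/OUT_DIM field names used below are assumptions drawn from
// fbgemm/ConvUtils.h, and the buffer sizes follow from them):
//
//   conv_param_t<2> conv_p(
//       /*MB=*/1, /*IC=*/32, /*OC=*/64, /*IN_DIM=*/{56, 56}, /*G=*/1,
//       /*K=*/{3, 3}, /*stride=*/{1, 1}, /*pad=*/{1, 1, 1, 1});
//   std::vector<std::int32_t> C(
//       conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC);
//   conv_ref(conv_p, A.data(), A_zero_point, B.data(), C.data());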

/*
 * @brief Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
 */
template <int SPATIAL_DIM = 2>
FBGEMM_API void transposeConvWeights(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const std::int8_t* src,
    std::int8_t* dest);

/*
 * @brief Reference implementation of im2col operation.
 *
 * For 2D:
 * The input A is assumed to be in NHiWiC format.
 * The output Ao is assumed to be in NHoWoRSC format.
 *
 * For 3D:
 * The input A is assumed to be in NTiHiWiC format.
 * The output Ao is assumed to be in NToHoWoK0K1K2C format.
 */
template <int SPATIAL_DIM = 2>
FBGEMM_API void im2col_ref(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const std::uint8_t* A,
    std::int32_t A_zero_point,
    std::uint8_t* Ao);

template <
    typename InType = std::uint8_t,
    typename IndexType = std::int64_t,
    typename OffsetType = std::int32_t,
    typename OutType = float>
FBGEMM_API bool EmbeddingSpMDM_ref(
    const std::int64_t block_size,
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t data_size,
    const InType* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    OutType* out,
    bool is_weight_positional = false,
    bool use_offsets = true,
    std::int64_t output_stride = -1,
    std::int64_t input_stride = -1,
    bool scale_bias_last = true,
    bool no_bag = false,
    bool is_bf16 = false);
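
// Usage sketch (illustrative; the table shape, indices, and offsets are
// assumptions): pool two bags from a float embedding table with 4 rows of
// 8 floats each, using offsets (use_offsets defaults to true).
//
//   std::vector<float> table(4 * 8), out(2 * 8);
//   std::vector<std::int64_t> indices = {0, 2, 3};
//   std::vector<std::int32_t> offsets = {0, 1, 3}; // bag 0: {0}, bag 1: {2, 3}
//   bool ok = EmbeddingSpMDM_ref<float>(
//       /*block_size=*/8, /*output_size=*/2, /*index_size=*/3,
//       /*data_size=*/4, table.data(), indices.data(), offsets.data(),
//       /*weights=*/nullptr, /*normalize_by_lengths=*/false, out.data());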

template <
    typename IndexType = std::int64_t,
    typename OffsetType = std::int32_t,
    typename OutType = float>
FBGEMM_API bool EmbeddingSpMDMNBit_ref(
    int bit_rate,
    const std::int64_t block_size,
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t data_size,
    const std::uint8_t* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    OutType* out,
    bool is_weight_positional = false,
    bool use_offsets = true,
    std::int64_t output_stride = -1,
    std::int64_t input_stride = -1,
    bool scale_bias_last = true);

template <
    typename IndexType = std::int64_t,
    typename OffsetType = std::int32_t,
    typename OutType = float>
bool EmbeddingSpMDMFP8_ref(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights,
    bool normalize_by_lengths,
    OutType* out,
    bool is_weight_positional = false,
    bool use_offsets = true,
    int64_t output_stride = -1,
    int64_t input_stride = -1,
    int exponent_bits = 4,
    int exponent_bias = 7);

template <
    typename InType = std::uint8_t,
    typename IndexType = std::int64_t,
    typename OffsetType = std::int32_t>
FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref(
    const std::int64_t block_size,
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t uncompressed_data_size,
    // const std::int64_t compressed_data_size,
    const InType* input,
    const IndexType* indices,
    const std::int32_t* compressed_indices_table,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    float* out,
    bool is_weight_positional = false,
    bool use_offsets = true);

template <typename IndexType = std::int64_t, typename OffsetType = std::int32_t>
FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref(
    int bit_rate,
    const std::int64_t block_size,
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t uncompressed_data_size,
    // const std::int64_t compressed_data_size,
    const std::uint8_t* input,
    const IndexType* indices,
    const std::int32_t* compressed_indices_table,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    float* out,
    bool is_weight_positional = false,
    bool use_offsets = true);

/**
 * @param num_rows number of rows to read
 * @param block_size number of parameters per row
 * @param param_size total number of parameters
 * @param w input parameters
 * @param g input gradients
 * @param h input momentum
 * @param indices indices of each row
 * @param counter used for frequency-adjusted weight_decay. nullptr when
 *                frequency adjustment is not used. Ignored when
 *                weight_decay == 0
 * @param counter_halflife weight_decay is adjusted only after this number of
 *                         iterations
 */
template <typename IndexType>
FBGEMM_API int sparse_adagrad_ref(
    int num_rows,
    int block_size,
    std::uint64_t param_size,
    float* w,
    const float* g,
    float* h,
    const IndexType* indices,
    float epsilon,
    float lr,
    float weight_decay = 0.f,
    const double* counter = nullptr,
    const int64_t counter_halflife = 0);
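
// Usage sketch (illustrative; shapes and hyperparameters are assumptions):
// update rows 3 and 7 of a 10 x 16 parameter matrix, with per-parameter
// momentum of the same shape as w.
//
//   std::vector<float> w(10 * 16), h(10 * 16), g(2 * 16);
//   std::vector<std::int64_t> indices = {3, 7};
//   int ret = sparse_adagrad_ref(
//       /*num_rows=*/2, /*block_size=*/16, /*param_size=*/w.size(),
//       w.data(), g.data(), h.data(), indices.data(),
//       /*epsilon=*/1e-5f, /*lr=*/0.1f);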

/**
 * @param num_rows number of rows to read
 * @param block_size number of parameters per row
 * @param param_size total number of parameters
 * @param w input parameters
 * @param g input gradients
 * @param h input momentum
 * @param indices indices of each row
 * @param counter used for frequency-adjusted weight_decay. nullptr when
 *                frequency adjustment is not used. Ignored when
 *                weight_decay == 0
 * @param counter_halflife weight_decay is adjusted only after this number of
 *                         iterations
 */
template <typename IndexType>
FBGEMM_API int rowwise_sparse_adagrad_ref(
    int num_rows,
    int block_size,
    std::uint64_t param_size,
    float* w,
    const float* g,
    float* h,
    const IndexType* indices,
    float epsilon,
    float lr,
    float weight_decay = 0.f,
    const double* counter = nullptr,
    const int64_t counter_halflife = 0);

template <typename DataType, typename IndexType, typename OffsetType>
FBGEMM_API int rowwise_sparse_adagrad_fused_ref(
    std::int64_t block_size,
    std::int64_t output_size,
    std::int64_t index_size,
    std::int64_t data_size,
    DataType* w, // input/output parameters
    const float* g, // input gradients
    float* h, // input/output momentums
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    float epsilon,
    float lr,
    bool use_offsets = true,
    bool use_stochastic_rounding = true, // For DataType=float16
    int emu_vector_size = 8,
    std::int64_t grad_stride = -1);

template <typename IndexType>
FBGEMM_API void compressed_indices_remap_ref(
    std::int32_t offsets_len,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const IndexType* offsets,
    const float* weights, // optional, can be null
    IndexType* out_indices,
    IndexType* out_offsets,
    float* out_weights);

template <typename T>
float convert_to_float_ref(T src, bool is_bf16 = false) {
  float f_value;
  if (std::is_same<T, uint16_t>::value) {
    f_value = is_bf16 ? cpu_bf162float(src) : cpu_half2float(src);
  } else {
    f_value = src;
  }
  return f_value;
}

template <typename T>
T convert_from_float_ref(float src, bool is_bf16 = false) {
  T o_value;
  if (std::is_same<T, uint16_t>::value) {
    o_value = is_bf16 ? cpu_float2bfloat16(src) : cpu_float2half_rn(src);
  } else {
    o_value = src;
  }
  return o_value;
}
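
// Usage sketch (illustrative): round-trip a float through fp16 and bf16
// storage with the helpers above.
//
//   std::uint16_t h = convert_from_float_ref<std::uint16_t>(1.5f);
//   float back_fp16 = convert_to_float_ref(h); // fp16 path
//   std::uint16_t b =
//       convert_from_float_ref<std::uint16_t>(1.5f, /*is_bf16=*/true);
//   float back_bf16 = convert_to_float_ref(b, /*is_bf16=*/true); // bf16 path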

} // namespace fbgemm