1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#pragma once
8#include "./ExecuteKernel.h"
9
10namespace fbgemm {
11
/**
 * @brief Execute engine for uint8 x int8 matrix multiplication: runs the
 * macro-kernel and applies output processing to matrix C.
 *
 * This is a partial specialization of ExecuteKernel for the case where
 * matrix B is packed as PackBMatrix<int8_t, packingAMatrix::accType>. It
 * derives from CodeGenBase<uint8_t, int8_t, int32_t, accType>.
 *
 * @tparam packingAMatrix     Packing class used for matrix A (uint8_t
 *                            elements). Must expose a nested `accType`,
 *                            which is also used as the accumulation type of
 *                            the packed B matrix and of CodeGenBase.
 * @tparam cT                 Element type of the output matrix C.
 * @tparam processOutputType  Type of the output-processing object applied to
 *                            matrix C in the macro-kernel.
 *
 * @note packA, packB and outputProcess are stored as reference data members
 * (packedA_, packedB_, outputProcess_), so they are not owned by this object
 * and must outlive every use of it.
 */
template <typename packingAMatrix, typename cT, typename processOutputType>
class ExecuteKernel<
    packingAMatrix,
    PackBMatrix<int8_t, typename packingAMatrix::accType>,
    cT,
    processOutputType>
    : public CodeGenBase<
          uint8_t,
          int8_t,
          int32_t,
          typename packingAMatrix::accType> {
 public:
  /// Alias for the code-generator base class this engine derives from.
  using BaseType =
      CodeGenBase<uint8_t, int8_t, int32_t, typename packingAMatrix::accType>;
  /**
   * @brief Constructor for initializing the parameters for the macro-kernel
   * and the output processing type.
   *
   * @param packA          Packed uint8 matrix A (held by reference).
   * @param packB          Packed int8 matrix B (held by reference).
   * @param matC           Output buffer for matrix C.
   * @param C_buffer       int32 accumulation buffer for matrix C.
   * @param ldc            Leading dimension of matrix C.
   * @param outputProcess  Output processing applied to matrix C in the
   *                       macro-kernel (held by reference).
   * @param th_info        Thread partition information.
   * @param params         Optional blocking factors; defaults to nullptr.
   *                       NOTE(review): presumably nullptr selects the
   *                       default blocking factors — confirm in the
   *                       constructor definition.
   */
  ExecuteKernel(
      PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>&
          packA,
      PackMatrix<
          PackBMatrix<int8_t, typename packingAMatrix::accType>,
          int8_t,
          typename packingAMatrix::accType>& packB,
      cT* matC,
      int32_t* C_buffer,
      int32_t ldc,
      const processOutputType& outputProcess,
      thread_type_t th_info,
      const BlockingFactors* params = nullptr);
  /**
   * @brief Run the macro-kernel and output processing.
   * @param kBlock  NOTE(review): presumably the index of the current block
   *                along the K dimension — confirm in the definition.
   */
  void execute(int kBlock);

 private:
  PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>&
      packedA_; ///< Packed uint8 block of matrix A (not owned).
  PackMatrix<
      PackBMatrix<int8_t, typename packingAMatrix::accType>,
      int8_t,
      typename packingAMatrix::accType>& packedB_; ///< Packed int8 matrix B
                                                   ///< (not owned).
  cT* matC_; ///< Output for matrix C.
  int32_t* C_buffer_; ///< The accumulation buffer for matrix C.
  int32_t ldc_; ///< The leading dimension of matrix C.
  const processOutputType& outputProcess_; ///< Output processing function for
                                           ///< matrix C in the macro-kernel
                                           ///< (not owned).
  thread_type_t
      th_info_; ///< The thread partition information (thread id and the
                ///< number of threads across the group, m and n dimensions).
  int mbSize_; ///< Block size in the m dimension.
  int nbSize_; ///< Block size in the n dimension.
  int nrMinSize_; ///< Minimum register size in the n dimension.
  int nrSize_; ///< Register size in the n dimension.
};
70
71} // namespace fbgemm
72