1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | #include "./ExecuteKernel.h" |
9 | |
10 | namespace fbgemm { |
11 | |
12 | /** |
13 | * @brief Execute Engine of uint 8 and int8 matrix |
14 | * multiplication for the macro-kernel and output processing. ExecuteKernel is a |
15 | * derived class of CodeGenBase. |
16 | */ |
17 | template <typename packingAMatrix, typename cT, typename processOutputType> |
18 | class ExecuteKernel< |
19 | packingAMatrix, |
20 | PackBMatrix<int8_t, typename packingAMatrix::accType>, |
21 | cT, |
22 | processOutputType> |
23 | : public CodeGenBase< |
24 | uint8_t, |
25 | int8_t, |
26 | int32_t, |
27 | typename packingAMatrix::accType> { |
28 | public: |
29 | using BaseType = |
30 | CodeGenBase<uint8_t, int8_t, int32_t, typename packingAMatrix::accType>; |
31 | /** |
32 | * @brief Constructor for initializing the parameters for macro-kernel and |
33 | * output processing type. |
34 | */ |
35 | ExecuteKernel( |
36 | PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>& |
37 | packA, |
38 | PackMatrix< |
39 | PackBMatrix<int8_t, typename packingAMatrix::accType>, |
40 | int8_t, |
41 | typename packingAMatrix::accType>& packB, |
42 | cT* matC, |
43 | int32_t* C_buffer, |
44 | int32_t ldc, |
45 | const processOutputType& outputProcess, |
46 | thread_type_t th_info, |
47 | const BlockingFactors* params = nullptr); |
48 | void execute(int kBlock); |
49 | |
50 | private: |
51 | PackMatrix<packingAMatrix, uint8_t, typename packingAMatrix::accType>& |
52 | packedA_; ///< Packed uint8 block of matrix A. |
53 | PackMatrix< |
54 | PackBMatrix<int8_t, typename packingAMatrix::accType>, |
55 | int8_t, |
56 | typename packingAMatrix::accType>& packedB_; ///< Packed int8 matrix B. |
57 | cT* matC_; ///< Output for matrix C. |
58 | int32_t* C_buffer_; ///< the accumulation buffer for matrix C. |
59 | int32_t ldc_; ///< the leading dimension of matrix C. |
60 | const processOutputType& outputProcess_; ///< output processing function for |
61 | ///< matrix C in the macro-kernel. |
62 | thread_type_t |
63 | th_info_; ///<< the thread partition information (thread id and the number |
64 | ///< of threads across the group, m, n dimensions. |
65 | int mbSize_; ///< block size in the m dimension. |
66 | int nbSize_; ///< block size in the n dimension. |
67 | int nrMinSize_; ///< minimum register size in the n dimension. |
68 | int nrSize_; ///< register size in the n dimension. |
69 | }; |
70 | |
71 | } // namespace fbgemm |
72 | |