1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | |
9 | #include <array> |
10 | #include <cstdint> |
11 | |
namespace fbgemm {

// Entry point for obtaining JIT-generated int8 depthwise-convolution kernels.
// This header only declares the interface; the definition of getOrCreate
// lives in a separate translation unit.
class GenI8Depthwise {
 public:
  // Pointer type of a generated kernel. Argument roles, as far as this
  // header shows them:
  //   a            - uint8 operand A (presumably the input activations).
  //   b            - int8 operand B (presumably the filter weights; confirm
  //                  the expected layout against the generator).
  //   c            - int32 output buffer written by the kernel (presumably
  //                  raw accumulators; confirm).
  //   a_sum        - row-wise sum of A (presumably only filled when the
  //                  kernel was created with compute_a_sum - TODO confirm).
  //   h, w         - presumably the spatial height / width of the input.
  //   ic           - number of input channels, which equals the number of
  //                  groups.
  //   mask         - purpose not visible here; presumably a lane mask for
  //                  the remainder channels - confirm.
  //   A_zero_point - zero point associated with A (per its name).
  using jit_kernel_signature = auto (*)(
      const std::uint8_t* a,
      const std::int8_t* b,
      std::int32_t* c,
      std::int32_t* a_sum,
      int h,
      int w,
      int ic,
      const int* mask,
      int A_zero_point) -> void;

  // Returns a kernel specialized for the given configuration. The name
  // suggests an already-generated kernel is reused when available and a new
  // one is generated otherwise; that logic is in the definition, not here.
  //   D             - dimension (presumably the number of spatial dims,
  //                   e.g. 2 or 3 - confirm).
  //   F             - filter size as {K_T, K_H, K_W}.
  //   oc_per_g      - number of output channels per group.
  //   compute_a_sum - presumably whether the kernel also produces a_sum.
  //   remainder     - number of channels handled by the remainder loop.
  //   prev_skip / next_skip, top_skip / bottom_skip, left_skip / right_skip
  //                 - not documented here; presumably how many filter taps
  //                   to skip at the T / H / W borders (e.g. for padding) -
  //                   confirm against the definition.
  auto getOrCreate(
      int D,
      std::array<int, 3> F,
      int oc_per_g,
      bool compute_a_sum,
      int remainder,
      int prev_skip,
      int next_skip,
      int top_skip,
      int bottom_skip,
      int left_skip,
      int right_skip) -> jit_kernel_signature;
};

} // namespace fbgemm
42 | |