1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#pragma once
8
9#include <array>
10#include <cstdint>
11
namespace fbgemm {

/// JIT code generator for int8 depthwise convolution kernels.
///
/// This header only declares the generated-kernel function-pointer type and
/// the factory entry point; the implementation is defined out of line.
class GenI8Depthwise {
 public:
  /// Pointer type of a generated depthwise kernel.
  ///
  /// @param a            uint8 input (A).
  /// @param b            int8 input (B).
  /// @param c            int32 output buffer.
  /// @param a_sum        row-wise sum of A. NOTE(review): presumably only
  ///                     meaningful when the kernel was generated with
  ///                     compute_a_sum == true — confirm.
  /// @param h            presumably the spatial height — confirm.
  /// @param w            presumably the spatial width — confirm.
  /// @param ic           number of input channels (== number of groups).
  /// @param mask         NOTE(review): semantics not visible in this header;
  ///                     possibly a lane mask for the remainder channels —
  ///                     confirm against the generator.
  /// @param A_zero_point zero-point value for A (presumably its quantization
  ///                     zero point).
  using jit_kernel_signature = auto (*)(
      const std::uint8_t* a,
      const std::int8_t* b,
      std::int32_t* c,
      std::int32_t* a_sum,
      int h,
      int w,
      int ic,
      const int* mask,
      int A_zero_point) -> void;

  /// Returns a kernel specialized for the given configuration.
  ///
  /// NOTE(review): the name suggests kernels are generated once and then
  /// reused for identical arguments — confirm against the definition.
  ///
  /// @param D             dimension (of the convolution).
  /// @param F             filter size as (K_T, K_H, K_W).
  /// @param oc_per_g      number of output channels per group.
  /// @param compute_a_sum whether the generated kernel should produce the
  ///                      row-wise sum of A (see jit_kernel_signature::a_sum).
  /// @param remainder     number of channels handled in the remainder loop.
  /// @param prev_skip     \
  /// @param next_skip      |  NOTE(review): presumably the number of filter
  /// @param top_skip       |  taps to skip at each boundary (temporal
  /// @param bottom_skip    |  front/back, top/bottom, left/right) — confirm.
  /// @param left_skip      |
  /// @param right_skip    /
  /// @return pointer to the generated kernel.
  auto getOrCreate(
      int D,
      std::array<int, 3> F,
      int oc_per_g,
      bool compute_a_sum,
      int remainder,
      int prev_skip,
      int next_skip,
      int top_skip,
      int bottom_skip,
      int left_skip,
      int right_skip) -> jit_kernel_signature;
};

} // namespace fbgemm
42