1/*******************************************************************************
2* Copyright 2018-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_GEMM_F32_GEMM_UTILS_F32_HPP
18#define CPU_GEMM_F32_GEMM_UTILS_F32_HPP
19
20#include <cstddef>
21
22namespace dnnl {
23namespace impl {
24namespace cpu {
25
26namespace gemm_utils {
// Compile-time constants for the GEMM routines, keyed by the element type T
// and by the two transposition flags. The primary template is empty, so an
// element type without a dedicated specialization exposes no constants.
template <class T, bool isTransA, bool isTransB>
struct gemm_traits {
};
29
30template <bool isTransA, bool isTransB>
31struct gemm_traits<double, isTransA, isTransB> {
32 static constexpr dim_t m = 8;
33 static constexpr dim_t n = 6;
34 static constexpr dim_t BM = 4032;
35 static constexpr dim_t BN = isTransA ? 96 : 192;
36 static constexpr dim_t BK = isTransB ? 96 : 512;
37};
38
39template <bool isTransA, bool isTransB>
40struct gemm_traits<float, isTransA, isTransB> {
41 static constexpr dim_t m = 16;
42 static constexpr dim_t n = 6;
43 static constexpr dim_t BM = 4032;
44 static constexpr dim_t BN = isTransA ? 96 : 48;
45 static constexpr dim_t BK = isTransB ? 96 : 256;
46};
47
48template <typename T>
49using unroll_factor = gemm_traits<T, false, false>;
50
51template <typename data_t>
52void sum_two_matrices(dim_t m, dim_t n, data_t *__restrict p_src, dim_t ld_src,
53 data_t *__restrict p_dst, dim_t ld_dst);
54
55void calc_nthr_nocopy_avx512_common(dim_t m, dim_t n, dim_t k, int nthrs,
56 int *nthrs_m, int *nthrs_n, int *nthrs_k, dim_t *BM, dim_t *BN,
57 dim_t *BK);
58
59void calc_nthr_nocopy_avx(dim_t m, dim_t n, dim_t k, int nthrs, int *nthrs_m,
60 int *nthrs_n, int *nthrs_k, dim_t *BM, dim_t *BN, dim_t *BK);
61
62void partition_unit_diff(
63 int ithr, int nthr, dim_t n, dim_t *t_offset, dim_t *t_block);
64}; // namespace gemm_utils
65
66} // namespace cpu
67} // namespace impl
68} // namespace dnnl
69#endif // CPU_GEMM_F32_GEMM_UTILS_F32_HPP
70