1/*******************************************************************************
2* Copyright 2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_GEMM_GEMM_UTILS_HPP
18#define CPU_X64_GEMM_GEMM_UTILS_HPP
19
20#include <tuple>
21
22#include "common/dnnl_thread.hpp"
23#include "common/dnnl_traits.hpp"
24#include "common/utils.hpp"
25
26#include "cpu/x64/gemm/gemm_pack_storage.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32namespace gemm_utils {
33
/* Computes a 2D thread grid (nthr_m x nthr_n) for partitioning an m x n
 * problem across at most nthrs threads.
 *
 * block_m/block_n seed the initial per-dimension thread counts;
 * small_m/small_n are the granularity that the per-thread chunk sizes
 * are rounded to.
 *
 * Outputs (by reference):
 *   thread_m, thread_n - per-thread chunk sizes along m and n, each a
 *                        multiple of small_m / small_n respectively.
 * Returns:
 *   std::tuple(nthr_m, nthr_n) - thread counts along m and n.
 */
static inline std::tuple<int, int> calc_nthr_2d(int nthrs, dim_t m, dim_t n,
        dim_t block_m, dim_t block_n, dim_t small_m, dim_t small_n,
        dim_t &thread_m, dim_t &thread_n) {

    // Initial estimate: one thread per block along each dimension.
    int nthr_m = utils::div_up(m, block_m);
    int nthr_n = utils::div_up(n, block_n);

    if (nthr_m < 1) nthr_m = 1;
    if (nthr_n < 1) nthr_n = 1;

    float ratio_float = (float)nthr_m / nthr_n;

    // Integer aspect ratio between the larger and the smaller thread count;
    // used below to shrink/grow the two dimensions proportionally.
    int ratio = 0;
    if (nthr_m > nthr_n)
        ratio = (int)ratio_float;
    else
        ratio = (int)(1. / ratio_float);

    // scale down nthr_m and nthr_n if they are too large
    while (nthr_m * nthr_n > 4 * nthrs) {
        nthr_m /= 2;
        nthr_n /= 2;
    }

    if (nthr_m < 1) nthr_m = 1;
    if (nthr_n < 1) nthr_n = 1;

    // Simple partition reduction: decrement the larger dimension `ratio`
    // times for every single decrement of the smaller one, until the grid
    // fits within nthrs threads.
    int counter = 0;
    while (nthr_m * nthr_n > nthrs) {
        if (nthr_m > nthr_n) {
            if (counter < ratio)
                nthr_m--;
            else {
                nthr_n--;
                counter = -1;
            }
        } else {
            if (counter < ratio)
                nthr_n--;
            else {
                nthr_m--;
                counter = -1;
            }
        }
        counter++;
    }

    // Simple partition increment: grow the grid with the same proportional
    // policy while utilization is below 95% of nthrs. (This may overshoot
    // nthrs; the fallback below handles that case.)
    counter = 0;
    while (nthr_m * nthr_n < 0.95 * nthrs) {
        if (nthr_m > nthr_n) {
            if (counter < ratio)
                nthr_m++;
            else {
                nthr_n++;
                counter = -1;
            }
        } else {
            if (counter < ratio)
                nthr_n++;
            else {
                nthr_m++;
                counter = -1;
            }
        }
        counter++;
    }

    // if nothing works out, then this should work
    if ((nthr_m * nthr_n > nthrs)) {

        if (nthr_m <= nthr_n) {
            // Start near sqrt(nthrs) threads along m (capped by how many
            // small_m-sized chunks fit in m), then search downward until
            // nthr_m evenly divides nthrs.
            nthr_m = (int)sqrt((double)nthrs);
            if (nthr_m > utils::div_up(m, small_m))
                nthr_m = utils::div_up(m, small_m);
            nthr_n = nthrs / nthr_m;

            while ((nthr_m > 1) && (nthr_m * nthr_n != nthrs)) {
                nthr_m--;
                nthr_n = nthrs / nthr_m;
            }
        } else {
            // Symmetric case: search along n instead.
            nthr_n = (int)sqrt((double)nthrs);
            if (nthr_n > utils::div_up(n, small_n))
                nthr_n = utils::div_up(n, small_n);
            nthr_m = nthrs / nthr_n;

            while ((nthr_n > 1) && (nthr_m * nthr_n != nthrs)) {
                nthr_n--;
                nthr_m = nthrs / nthr_n;
            }
        }
    }

    // Per-thread chunk sizes: ceil-divide, then round up to a multiple of
    // small_m / small_n (adding small-1 then truncating the remainder).
    thread_m = utils::div_up(m, nthr_m) + small_m - 1;
    thread_n = utils::div_up(n, nthr_n) + small_n - 1;
    thread_m -= thread_m % small_m;
    thread_n -= thread_n % small_n;

    // Drop threads whose chunks would fall entirely past the matrix edge.
    if (thread_m * nthr_m > m) nthr_m = utils::div_up(m, thread_m);
    if (thread_n * nthr_n > n) nthr_n = utils::div_up(n, thread_n);

    return std::make_tuple(nthr_m, nthr_n);
}
139
140template <typename T>
141static inline dim_t get_ld_padd(const dim_t x) {
142 return x != 1 ? utils::rnd_up(x, 2048 / sizeof(T)) + (64 / sizeof(T)) : 1;
143}
144
145template <typename mat_t, typename acc_t>
146void prep_gemm_pack(bool do_a, int is_trans, dim_t nrows, dim_t ncols,
147 gemm_pack_storage_t *pack_dst) {
148
149 auto ld = !is_trans ? get_ld_padd<mat_t>(nrows) : get_ld_padd<mat_t>(ncols);
150 auto td = !is_trans ? ncols : nrows;
151
152 // TODO Do we need to use only one thread?
153 pack_dst->which() = do_a ? matrix_id::a : matrix_id::b;
154 pack_dst->setup(1);
155 pack_dst->threading().copy = copy_type::no_copy;
156 pack_dst->threading().nthrs_m = 1;
157 pack_dst->threading().nthrs_n = 1;
158 pack_dst->threading().nthrs_k = 1;
159 pack_dst->set_nocopy(0, is_trans, ld, td);
160 pack_dst->finalize<mat_t, acc_t>();
161}
162
163template <typename T>
164dnnl_status_t pack_no_copy(const T *src, dim_t ld_src, dim_t nrows, dim_t ncols,
165 int trans_src, float alpha, gemm_pack_storage_t *dst_pack) {
166
167 auto dst = dst_pack->matrix<T>(0);
168 int trans_dst;
169 dim_t nrows_dst, ncols_dst;
170 dim_t ld_dst, td_dst;
171
172 constexpr bool is_f32 = data_traits<T>::data_type == data_type::f32;
173
174 if (!dst_pack->get_nocopy(0, trans_dst, ld_dst, td_dst))
175 return dnnl_invalid_arguments;
176
177 if (!trans_dst) {
178 nrows_dst = nrows;
179 ncols_dst = ncols;
180 } else {
181 nrows_dst = ncols;
182 ncols_dst = nrows;
183 }
184
185 if (trans_src == trans_dst) {
186 parallel_nd(ncols_dst, [=](dim_t j) {
187 auto src_col = src + j * ld_src;
188 auto dst_col = dst + j * ld_dst;
189
190 PRAGMA_OMP_SIMD()
191 for (dim_t i = 0; i < nrows_dst; i++)
192 if (is_f32)
193 dst_col[i] = alpha * src_col[i];
194 else
195 dst_col[i] = src_col[i];
196 });
197 } else {
198 // Naive code for now.
199 parallel_nd(ncols_dst, [=](dim_t j) {
200 auto src_col = src + j;
201 auto dst_col = dst + j * ld_dst;
202
203 PRAGMA_OMP_SIMD()
204 for (dim_t i = 0; i < nrows_dst; i++)
205 if (is_f32)
206 dst_col[i] = alpha * src_col[i * ld_src];
207 else
208 dst_col[i] = src_col[i * ld_src];
209 });
210 }
211
212 return dnnl_success;
213}
214
215} // namespace gemm_utils
216} // namespace x64
217} // namespace cpu
218} // namespace impl
219} // namespace dnnl
220
221#endif // CPU_X64_GEMM_GEMM_UTILS_HPP
222