1 | /******************************************************************************* |
2 | * Copyright 2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_GEMM_GEMM_UTILS_HPP |
18 | #define CPU_X64_GEMM_GEMM_UTILS_HPP |
19 | |
#include <cmath>
#include <tuple>

#include "common/dnnl_thread.hpp"
#include "common/dnnl_traits.hpp"
#include "common/utils.hpp"

#include "cpu/x64/gemm/gemm_pack_storage.hpp"
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | namespace gemm_utils { |
33 | |
34 | static inline std::tuple<int, int> calc_nthr_2d(int nthrs, dim_t m, dim_t n, |
35 | dim_t block_m, dim_t block_n, dim_t small_m, dim_t small_n, |
36 | dim_t &thread_m, dim_t &thread_n) { |
37 | |
38 | int nthr_m = utils::div_up(m, block_m); |
39 | int nthr_n = utils::div_up(n, block_n); |
40 | |
41 | if (nthr_m < 1) nthr_m = 1; |
42 | if (nthr_n < 1) nthr_n = 1; |
43 | |
44 | float ratio_float = (float)nthr_m / nthr_n; |
45 | |
46 | int ratio = 0; |
47 | if (nthr_m > nthr_n) |
48 | ratio = (int)ratio_float; |
49 | else |
50 | ratio = (int)(1. / ratio_float); |
51 | |
52 | // scale down nthr_m and nthr_n if they are too large |
53 | while (nthr_m * nthr_n > 4 * nthrs) { |
54 | nthr_m /= 2; |
55 | nthr_n /= 2; |
56 | } |
57 | |
58 | if (nthr_m < 1) nthr_m = 1; |
59 | if (nthr_n < 1) nthr_n = 1; |
60 | |
61 | // Simple partition reduction |
62 | int counter = 0; |
63 | while (nthr_m * nthr_n > nthrs) { |
64 | if (nthr_m > nthr_n) { |
65 | if (counter < ratio) |
66 | nthr_m--; |
67 | else { |
68 | nthr_n--; |
69 | counter = -1; |
70 | } |
71 | } else { |
72 | if (counter < ratio) |
73 | nthr_n--; |
74 | else { |
75 | nthr_m--; |
76 | counter = -1; |
77 | } |
78 | } |
79 | counter++; |
80 | } |
81 | |
82 | // Simple partition increment |
83 | counter = 0; |
84 | while (nthr_m * nthr_n < 0.95 * nthrs) { |
85 | if (nthr_m > nthr_n) { |
86 | if (counter < ratio) |
87 | nthr_m++; |
88 | else { |
89 | nthr_n++; |
90 | counter = -1; |
91 | } |
92 | } else { |
93 | if (counter < ratio) |
94 | nthr_n++; |
95 | else { |
96 | nthr_m++; |
97 | counter = -1; |
98 | } |
99 | } |
100 | counter++; |
101 | } |
102 | |
103 | // if nothing works out, then this should work |
104 | if ((nthr_m * nthr_n > nthrs)) { |
105 | |
106 | if (nthr_m <= nthr_n) { |
107 | nthr_m = (int)sqrt((double)nthrs); |
108 | if (nthr_m > utils::div_up(m, small_m)) |
109 | nthr_m = utils::div_up(m, small_m); |
110 | nthr_n = nthrs / nthr_m; |
111 | |
112 | while ((nthr_m > 1) && (nthr_m * nthr_n != nthrs)) { |
113 | nthr_m--; |
114 | nthr_n = nthrs / nthr_m; |
115 | } |
116 | } else { |
117 | nthr_n = (int)sqrt((double)nthrs); |
118 | if (nthr_n > utils::div_up(n, small_n)) |
119 | nthr_n = utils::div_up(n, small_n); |
120 | nthr_m = nthrs / nthr_n; |
121 | |
122 | while ((nthr_n > 1) && (nthr_m * nthr_n != nthrs)) { |
123 | nthr_n--; |
124 | nthr_m = nthrs / nthr_n; |
125 | } |
126 | } |
127 | } |
128 | |
129 | thread_m = utils::div_up(m, nthr_m) + small_m - 1; |
130 | thread_n = utils::div_up(n, nthr_n) + small_n - 1; |
131 | thread_m -= thread_m % small_m; |
132 | thread_n -= thread_n % small_n; |
133 | |
134 | if (thread_m * nthr_m > m) nthr_m = utils::div_up(m, thread_m); |
135 | if (thread_n * nthr_n > n) nthr_n = utils::div_up(n, thread_n); |
136 | |
137 | return std::make_tuple(nthr_m, nthr_n); |
138 | } |
139 | |
140 | template <typename T> |
141 | static inline dim_t get_ld_padd(const dim_t x) { |
142 | return x != 1 ? utils::rnd_up(x, 2048 / sizeof(T)) + (64 / sizeof(T)) : 1; |
143 | } |
144 | |
145 | template <typename mat_t, typename acc_t> |
146 | void prep_gemm_pack(bool do_a, int is_trans, dim_t nrows, dim_t ncols, |
147 | gemm_pack_storage_t *pack_dst) { |
148 | |
149 | auto ld = !is_trans ? get_ld_padd<mat_t>(nrows) : get_ld_padd<mat_t>(ncols); |
150 | auto td = !is_trans ? ncols : nrows; |
151 | |
152 | // TODO Do we need to use only one thread? |
153 | pack_dst->which() = do_a ? matrix_id::a : matrix_id::b; |
154 | pack_dst->setup(1); |
155 | pack_dst->threading().copy = copy_type::no_copy; |
156 | pack_dst->threading().nthrs_m = 1; |
157 | pack_dst->threading().nthrs_n = 1; |
158 | pack_dst->threading().nthrs_k = 1; |
159 | pack_dst->set_nocopy(0, is_trans, ld, td); |
160 | pack_dst->finalize<mat_t, acc_t>(); |
161 | } |
162 | |
163 | template <typename T> |
164 | dnnl_status_t pack_no_copy(const T *src, dim_t ld_src, dim_t nrows, dim_t ncols, |
165 | int trans_src, float alpha, gemm_pack_storage_t *dst_pack) { |
166 | |
167 | auto dst = dst_pack->matrix<T>(0); |
168 | int trans_dst; |
169 | dim_t nrows_dst, ncols_dst; |
170 | dim_t ld_dst, td_dst; |
171 | |
172 | constexpr bool is_f32 = data_traits<T>::data_type == data_type::f32; |
173 | |
174 | if (!dst_pack->get_nocopy(0, trans_dst, ld_dst, td_dst)) |
175 | return dnnl_invalid_arguments; |
176 | |
177 | if (!trans_dst) { |
178 | nrows_dst = nrows; |
179 | ncols_dst = ncols; |
180 | } else { |
181 | nrows_dst = ncols; |
182 | ncols_dst = nrows; |
183 | } |
184 | |
185 | if (trans_src == trans_dst) { |
186 | parallel_nd(ncols_dst, [=](dim_t j) { |
187 | auto src_col = src + j * ld_src; |
188 | auto dst_col = dst + j * ld_dst; |
189 | |
190 | PRAGMA_OMP_SIMD() |
191 | for (dim_t i = 0; i < nrows_dst; i++) |
192 | if (is_f32) |
193 | dst_col[i] = alpha * src_col[i]; |
194 | else |
195 | dst_col[i] = src_col[i]; |
196 | }); |
197 | } else { |
198 | // Naive code for now. |
199 | parallel_nd(ncols_dst, [=](dim_t j) { |
200 | auto src_col = src + j; |
201 | auto dst_col = dst + j * ld_dst; |
202 | |
203 | PRAGMA_OMP_SIMD() |
204 | for (dim_t i = 0; i < nrows_dst; i++) |
205 | if (is_f32) |
206 | dst_col[i] = alpha * src_col[i * ld_src]; |
207 | else |
208 | dst_col[i] = src_col[i * ld_src]; |
209 | }); |
210 | } |
211 | |
212 | return dnnl_success; |
213 | } |
214 | |
215 | } // namespace gemm_utils |
216 | } // namespace x64 |
217 | } // namespace cpu |
218 | } // namespace impl |
219 | } // namespace dnnl |
220 | |
221 | #endif // CPU_X64_GEMM_GEMM_UTILS_HPP |
222 | |