1 | /******************************************************************************* |
2 | * Copyright 2016-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_CPU_INNER_PRODUCT_PD_HPP |
18 | #define CPU_CPU_INNER_PRODUCT_PD_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/inner_product_pd.hpp" |
24 | #include "common/utils.hpp" |
25 | #include "cpu/cpu_engine.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | |
31 | namespace { |
32 | inline bool dense_gemm_consitency_check(const memory_desc_wrapper &src_d, |
33 | const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) { |
34 | using namespace utils; |
35 | |
36 | auto strides_compatible = [&]() { |
37 | bool ok = true; |
38 | auto w_str = wei_d.blocking_desc().strides; |
39 | auto d_str = src_d.blocking_desc().strides; |
40 | for (int i = 1; i < src_d.ndims() - 1; i++) { |
41 | ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1]; |
42 | } |
43 | return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]); |
44 | }; |
45 | |
46 | auto inner_blk_compatible = [&]() { |
47 | auto d_inner_blks = src_d.blocking_desc().inner_blks; |
48 | auto w_inner_blks = wei_d.blocking_desc().inner_blks; |
49 | auto d_inner_idxs = src_d.blocking_desc().inner_idxs; |
50 | auto w_inner_idxs = wei_d.blocking_desc().inner_idxs; |
51 | |
52 | int d_inner_nblks = src_d.blocking_desc().inner_nblks; |
53 | int w_inner_nblks = wei_d.blocking_desc().inner_nblks; |
54 | |
55 | bool ok = true; |
56 | |
57 | if ((wei_d.blocking_desc().strides[0] == 1) && (w_inner_nblks > 0)) { |
58 | ok = ok && wei_d.dims()[0] / w_inner_blks[w_inner_nblks - 1] == 1 |
59 | && w_inner_idxs[w_inner_nblks - 1] == 0; |
60 | w_inner_nblks--; |
61 | } |
62 | ok = ok && d_inner_nblks == w_inner_nblks; |
63 | |
64 | for (int d = 0; d < w_inner_nblks; d++) |
65 | ok = ok && (d_inner_blks[d] == w_inner_blks[d]) |
66 | && (d_inner_idxs[d] == w_inner_idxs[d]); |
67 | |
68 | return ok; |
69 | }; |
70 | |
71 | return true && src_d.is_blocking_desc() && wei_d.is_blocking_desc() |
72 | && src_d.ndims() == wei_d.ndims() && inner_blk_compatible() |
73 | && strides_compatible() && dst_d.matches_tag(format_tag::nc) |
74 | && src_d.only_padded_dim(1) && wei_d.only_padded_dim(1) |
75 | && src_d.padded_dims()[1] == wei_d.padded_dims()[1] |
76 | && src_d.is_dense(true) && dst_d.is_dense() && wei_d.is_dense(true); |
77 | } |
78 | |
79 | void transpose_md(memory_desc_t &md) { |
80 | // Note: we cannot directly use good leading dimension for a |
81 | // in padded_dims. This is because inner_blks does not |
82 | // account for padding, and should divide the corresponding |
83 | // padded_dim. |
84 | auto put_a_last = [](memory_desc_t &md) { |
85 | auto &md_blk = md.format_desc.blocking; |
86 | md.padded_dims[0] = md.dims[0]; |
87 | md_blk.strides[0] = 1; |
88 | for (int d = 1; d < md.ndims; d++) |
89 | md_blk.strides[d] *= md.padded_dims[0]; |
90 | if (md_blk.inner_nblks > 0) { |
91 | md_blk.inner_idxs[md_blk.inner_nblks] = 0; |
92 | md_blk.inner_blks[md_blk.inner_nblks] = md.padded_dims[0]; |
93 | md_blk.inner_nblks++; |
94 | } |
95 | }; |
96 | |
97 | auto put_a_first = [](memory_desc_t &md) { |
98 | blocking_desc_t blk = md.format_desc.blocking; |
99 | // make the stride for `a` bigger than any other stride and |
100 | // use the fact that memory_desc_init_by_blocking_desc |
101 | // preserves the strides order but actually changes them to |
102 | // densify the descriptor |
103 | blk.strides[0] = memory_desc_wrapper(md).size(); |
104 | memory_desc_init_by_blocking_desc(md, blk); |
105 | }; |
106 | |
107 | auto is_a_last = [](memory_desc_t &md) { |
108 | auto &md_blk = md.format_desc.blocking; |
109 | // The inner_blks condition makes sure that a is a non blocked dimension |
110 | return (md_blk.strides[0] == 1) && (md_blk.inner_nblks == 0); |
111 | }; |
112 | |
113 | auto is_a_first = [&](memory_desc_t &md) { |
114 | auto &md_blk = md.format_desc.blocking; |
115 | for (int d = 1; d < md.ndims; d++) |
116 | if (md_blk.strides[0] < md_blk.strides[d]) return false; |
117 | return true; |
118 | }; |
119 | |
120 | if (is_a_last(md)) |
121 | put_a_first(md); |
122 | else if (is_a_first(md)) |
123 | put_a_last(md); |
124 | |
125 | // here, by default we do not transpose md if it is not |
126 | } |
127 | |
128 | format_tag_t get_tag(memory_desc_t &md) { |
129 | using namespace format_tag; |
130 | auto tag = memory_desc_matches_one_of_tag(md, ab, abc, abcd, |
131 | abcde, // NCHW derivatives |
132 | ba, bca, bcda, bcdea, cba, cdba, |
133 | cdeba, // IO and spatial derivatives |
134 | acb, acdb, acdeb, // NHWC derivatives |
135 | aBcd16b, aBcde16b, aBcd8b, aBcde8b, aBcd4b, |
136 | aBcde4b); // blocked layouts |
137 | return tag; |
138 | } |
139 | |
140 | inline bool is_ineff_lead_dim(const dim_t dim) { |
141 | return dim % 1024 == 0; // check cache aliasing |
142 | } |
143 | |
144 | /* Pick between M and K for the most efficient leading |
145 | * dimension to compute GeMM. */ |
146 | bool transpose_leading_dim(const dim_t M, const dim_t K) { |
147 | return IMPLICATION(is_ineff_lead_dim(M), is_ineff_lead_dim(K) && M <= K); |
148 | } |
149 | } // namespace |
150 | |
// Initializes memory descriptor `md` from the format tag produced by the
// expression `tag_init_f`.
//
// NOTE: this macro contains hidden `return` statements. It leaves the
// enclosing function (or lambda) with status::unimplemented when the tag is
// format_tag::undef, and CHECK() is used on the result of
// memory_desc_init_by_tag(), so the enclosing callable must return a status.
//
// The macro-local variable has a deliberately unusual name so that an
// argument expression mentioning an identifier such as `tag` is not
// captured by it (which would turn the initialization into a read of an
// uninitialized variable). Arguments are parenthesized for the usual macro
// precedence reasons; `tag_init_f` is evaluated exactly once.
#define INIT_MEM_BY_TAG(tag_init_f, md) \
    do { \
        const auto init_mem_by_tag_value_ = (tag_init_f); \
        if (init_mem_by_tag_value_ == format_tag::undef) \
            return status::unimplemented; \
        CHECK(memory_desc_init_by_tag((md), init_mem_by_tag_value_)); \
    } while (0)
157 | |
158 | struct cpu_inner_product_fwd_pd_t : public inner_product_fwd_pd_t { |
159 | using inner_product_fwd_pd_t::inner_product_fwd_pd_t; |
160 | |
161 | protected: |
162 | status_t set_default_params(bool allow_all_tags = false) { |
163 | using namespace format_tag; |
164 | |
165 | auto set_default_src = [&]() { |
166 | if (weights_md_.format_kind == format_kind::any) { |
167 | INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde), |
168 | src_md_); |
169 | } else { |
170 | format_tag_t weights_tag = get_tag(weights_md_); |
171 | if (allow_all_tags && weights_tag == undef) { |
172 | INIT_MEM_BY_TAG( |
173 | utils::pick(ndims() - 2, ab, abc, abcd, abcde), |
174 | src_md_); |
175 | } else { |
176 | INIT_MEM_BY_TAG(weights_tag, src_md_); |
177 | } |
178 | // transpose weights to improve efficiency of non-copy kernels |
179 | if (src_md_.format_desc.blocking.strides[0] == 1) |
180 | transpose_md(src_md_); |
181 | } |
182 | return status::success; |
183 | }; |
184 | |
185 | auto set_default_weights = [&]() { |
186 | format_tag_t src_tag = get_tag(src_md_); |
187 | if (allow_all_tags && src_tag == undef) { |
188 | INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde), |
189 | weights_md_); |
190 | } else { |
191 | INIT_MEM_BY_TAG(src_tag, weights_md_); |
192 | } |
193 | /* with batch = 1, no transpose to use the faster gemv kernels */ |
194 | /* otherwise, we transpose the weights to improve efficiency of |
195 | * no-copy kernels */ |
196 | if (MB() > 1 && transpose_leading_dim(OC(), IC_total())) |
197 | transpose_md(weights_md_); |
198 | return status::success; |
199 | }; |
200 | |
201 | if (src_md_.format_kind == format_kind::any) CHECK(set_default_src()); |
202 | if (weights_md_.format_kind == format_kind::any) |
203 | CHECK(set_default_weights()); |
204 | if (dst_md_.format_kind == format_kind::any) |
205 | CHECK(memory_desc_init_by_tag(dst_md_, nc)); |
206 | if (bias_md_.format_kind == format_kind::any) |
207 | CHECK(memory_desc_init_by_tag(bias_md_, x)); |
208 | return status::success; |
209 | } |
210 | }; |
211 | |
212 | struct cpu_inner_product_bwd_data_pd_t : public inner_product_bwd_data_pd_t { |
213 | using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t; |
214 | |
215 | protected: |
216 | status_t set_default_params(bool allow_all_tags = false) { |
217 | using namespace format_tag; |
218 | |
219 | auto set_default_diff_src = [&]() { |
220 | if (weights_md_.format_kind == format_kind::any) { |
221 | INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde), |
222 | diff_src_md_); |
223 | } else { |
224 | format_tag_t weights_tag = get_tag(weights_md_); |
225 | if (allow_all_tags && weights_tag == undef) { |
226 | INIT_MEM_BY_TAG( |
227 | utils::pick(ndims() - 2, ab, abc, abcd, abcde), |
228 | diff_src_md_); |
229 | } else { |
230 | INIT_MEM_BY_TAG(weights_tag, diff_src_md_); |
231 | } |
232 | if (diff_src_md_.format_desc.blocking.strides[0] == 1) |
233 | transpose_md(diff_src_md_); |
234 | } |
235 | return status::success; |
236 | }; |
237 | |
238 | auto set_default_weights = [&]() { |
239 | format_tag_t diff_src_tag = get_tag(diff_src_md_); |
240 | if (allow_all_tags && diff_src_tag == undef) { |
241 | INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde), |
242 | weights_md_); |
243 | } else { |
244 | INIT_MEM_BY_TAG(diff_src_tag, weights_md_); |
245 | } |
246 | /* with batch = 1, no transpose to use the faster gemv kernels */ |
247 | /* otherwise, we transpose the weights to improve efficiency of |
248 | * no-copy kernels */ |
249 | if (MB() == 1) transpose_md(weights_md_); |
250 | |
251 | return status::success; |
252 | }; |
253 | |
254 | if (diff_src_md_.format_kind == format_kind::any) |
255 | CHECK(set_default_diff_src()); |
256 | if (weights_md_.format_kind == format_kind::any) |
257 | CHECK(set_default_weights()); |
258 | if (diff_dst_md_.format_kind == format_kind::any) |
259 | CHECK(memory_desc_init_by_tag(diff_dst_md_, nc)); |
260 | return status::success; |
261 | } |
262 | }; |
263 | |
264 | struct cpu_inner_product_bwd_weights_pd_t |
265 | : public inner_product_bwd_weights_pd_t { |
266 | using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t; |
267 | |
268 | protected: |
269 | status_t set_default_params(bool allow_all_tags = false) { |
270 | using namespace format_tag; |
271 | |
272 | auto set_default_src = [&]() { |
273 | if (diff_weights_md_.format_kind == format_kind::any) { |
274 | INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde), |
275 | src_md_); |
276 | } else { |
277 | format_tag_t diff_weights_tag = get_tag(diff_weights_md_); |
278 | if (allow_all_tags && diff_weights_tag == undef) { |
279 | INIT_MEM_BY_TAG( |
280 | utils::pick(ndims() - 2, ab, abc, abcd, abcde), |
281 | src_md_); |
282 | } else { |
283 | INIT_MEM_BY_TAG(diff_weights_tag, src_md_); |
284 | } |
285 | if (src_md_.format_desc.blocking.strides[0] == 1) |
286 | transpose_md(src_md_); |
287 | } |
288 | return status::success; |
289 | }; |
290 | |
291 | auto set_default_diff_weights = [&]() { |
292 | format_tag_t src_tag = get_tag(src_md_); |
293 | if (allow_all_tags && src_tag == undef) { |
294 | INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde), |
295 | diff_weights_md_); |
296 | } else { |
297 | INIT_MEM_BY_TAG(src_tag, diff_weights_md_); |
298 | } |
299 | // Here, we want diff_weights layout to match the fwd weights layout |
300 | if (MB() > 1 && transpose_leading_dim(OC(), MB())) |
301 | transpose_md(diff_weights_md_); |
302 | return status::success; |
303 | }; |
304 | |
305 | if (src_md_.format_kind == format_kind::any) CHECK(set_default_src()); |
306 | if (diff_weights_md_.format_kind == format_kind::any) |
307 | CHECK(set_default_diff_weights()); |
308 | if (diff_dst_md_.format_kind == format_kind::any) |
309 | CHECK(memory_desc_init_by_tag(diff_dst_md_, nc)); |
310 | if (diff_bias_md_.format_kind == format_kind::any) |
311 | CHECK(memory_desc_init_by_tag(diff_bias_md_, x)); |
312 | return status::success; |
313 | } |
314 | }; |
315 | #undef INIT_MEM_BY_TAG |
316 | |
317 | } // namespace cpu |
318 | } // namespace impl |
319 | } // namespace dnnl |
320 | |
321 | #endif |
322 | |
323 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
324 | |