1/*******************************************************************************
2* Copyright 2016-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_CPU_INNER_PRODUCT_PD_HPP
18#define CPU_CPU_INNER_PRODUCT_PD_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/inner_product_pd.hpp"
24#include "common/utils.hpp"
25#include "cpu/cpu_engine.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30
31namespace {
32inline bool dense_gemm_consitency_check(const memory_desc_wrapper &src_d,
33 const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) {
34 using namespace utils;
35
36 auto strides_compatible = [&]() {
37 bool ok = true;
38 auto w_str = wei_d.blocking_desc().strides;
39 auto d_str = src_d.blocking_desc().strides;
40 for (int i = 1; i < src_d.ndims() - 1; i++) {
41 ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1];
42 }
43 return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]);
44 };
45
46 auto inner_blk_compatible = [&]() {
47 auto d_inner_blks = src_d.blocking_desc().inner_blks;
48 auto w_inner_blks = wei_d.blocking_desc().inner_blks;
49 auto d_inner_idxs = src_d.blocking_desc().inner_idxs;
50 auto w_inner_idxs = wei_d.blocking_desc().inner_idxs;
51
52 int d_inner_nblks = src_d.blocking_desc().inner_nblks;
53 int w_inner_nblks = wei_d.blocking_desc().inner_nblks;
54
55 bool ok = true;
56
57 if ((wei_d.blocking_desc().strides[0] == 1) && (w_inner_nblks > 0)) {
58 ok = ok && wei_d.dims()[0] / w_inner_blks[w_inner_nblks - 1] == 1
59 && w_inner_idxs[w_inner_nblks - 1] == 0;
60 w_inner_nblks--;
61 }
62 ok = ok && d_inner_nblks == w_inner_nblks;
63
64 for (int d = 0; d < w_inner_nblks; d++)
65 ok = ok && (d_inner_blks[d] == w_inner_blks[d])
66 && (d_inner_idxs[d] == w_inner_idxs[d]);
67
68 return ok;
69 };
70
71 return true && src_d.is_blocking_desc() && wei_d.is_blocking_desc()
72 && src_d.ndims() == wei_d.ndims() && inner_blk_compatible()
73 && strides_compatible() && dst_d.matches_tag(format_tag::nc)
74 && src_d.only_padded_dim(1) && wei_d.only_padded_dim(1)
75 && src_d.padded_dims()[1] == wei_d.padded_dims()[1]
76 && src_d.is_dense(true) && dst_d.is_dense() && wei_d.is_dense(true);
77}
78
// Moves dimension `a` (dim 0) of `md` between the outermost and innermost
// position: if `a` is currently innermost (unit stride, non-blocked) it is
// made outermost, and if it is outermost it is made innermost. Layouts where
// `a` is neither are left unchanged.
void transpose_md(memory_desc_t &md) {
    // Note: we cannot directly use good leading dimension for a
    // in padded_dims. This is because inner_blks does not
    // account for padding, and should divide the corresponding
    // padded_dim.
    auto put_a_last = [](memory_desc_t &md) {
        auto &md_blk = md.format_desc.blocking;
        // Drop padding on `a` first: the unpadded size is what scales the
        // remaining strides and becomes the inner block (see note above).
        md.padded_dims[0] = md.dims[0];
        md_blk.strides[0] = 1;
        for (int d = 1; d < md.ndims; d++)
            md_blk.strides[d] *= md.padded_dims[0];
        if (md_blk.inner_nblks > 0) {
            // Blocked layout: append `a` as a new innermost block.
            md_blk.inner_idxs[md_blk.inner_nblks] = 0;
            md_blk.inner_blks[md_blk.inner_nblks] = md.padded_dims[0];
            md_blk.inner_nblks++;
        }
    };

    auto put_a_first = [](memory_desc_t &md) {
        blocking_desc_t blk = md.format_desc.blocking;
        // make the stride for `a` bigger than any other stride and
        // use the fact that memory_desc_init_by_blocking_desc
        // preserves the strides order but actually changes them to
        // densify the descriptor
        blk.strides[0] = memory_desc_wrapper(md).size();
        memory_desc_init_by_blocking_desc(md, blk);
    };

    auto is_a_last = [](memory_desc_t &md) {
        auto &md_blk = md.format_desc.blocking;
        // The inner_blks condition makes sure that a is a non blocked dimension
        return (md_blk.strides[0] == 1) && (md_blk.inner_nblks == 0);
    };

    auto is_a_first = [&](memory_desc_t &md) {
        auto &md_blk = md.format_desc.blocking;
        // `a` is first iff no other dimension has a strictly larger stride.
        for (int d = 1; d < md.ndims; d++)
            if (md_blk.strides[0] < md_blk.strides[d]) return false;
        return true;
    };

    if (is_a_last(md))
        put_a_first(md);
    else if (is_a_first(md))
        put_a_last(md);

    // If `a` is neither first nor last, md is deliberately left as-is.
}
127
128format_tag_t get_tag(memory_desc_t &md) {
129 using namespace format_tag;
130 auto tag = memory_desc_matches_one_of_tag(md, ab, abc, abcd,
131 abcde, // NCHW derivatives
132 ba, bca, bcda, bcdea, cba, cdba,
133 cdeba, // IO and spatial derivatives
134 acb, acdb, acdeb, // NHWC derivatives
135 aBcd16b, aBcde16b, aBcd8b, aBcde8b, aBcd4b,
136 aBcde4b); // blocked layouts
137 return tag;
138}
139
140inline bool is_ineff_lead_dim(const dim_t dim) {
141 return dim % 1024 == 0; // check cache aliasing
142}
143
144/* Pick between M and K for the most efficient leading
145 * dimension to compute GeMM. */
146bool transpose_leading_dim(const dim_t M, const dim_t K) {
147 return IMPLICATION(is_ineff_lead_dim(M), is_ineff_lead_dim(K) && M <= K);
148}
149} // namespace
150
// Initializes memory descriptor `md` from the tag produced by evaluating
// `tag_init_f`. Bails out of the *enclosing function* with
// status::unimplemented when the tag is undef, and propagates failures of
// memory_desc_init_by_tag via CHECK. Must be a macro because it returns from
// the caller's scope.
#define INIT_MEM_BY_TAG(tag_init_f, md) \
    do { \
        auto tag = tag_init_f; \
        if (tag == format_tag::undef) return status::unimplemented; \
        CHECK(memory_desc_init_by_tag(md, tag)); \
    } while (0)
157
// Base primitive descriptor for CPU inner product forward implementations.
// Adds the common logic that resolves `format_kind::any` memory descriptors
// to concrete layouts suitable for the GeMM-based kernels.
struct cpu_inner_product_fwd_pd_t : public inner_product_fwd_pd_t {
    using inner_product_fwd_pd_t::inner_product_fwd_pd_t;

protected:
    // Resolves any `format_kind::any` among src/weights/dst/bias.
    // When `allow_all_tags` is set, a src/weights layout that get_tag() does
    // not recognize falls back to the plain NC{D,H,W} layout instead of
    // returning unimplemented.
    status_t set_default_params(bool allow_all_tags = false) {
        using namespace format_tag;

        auto set_default_src = [&]() {
            // If weights are also `any`, pick the plain layout; otherwise
            // mirror the (user-provided) weights layout.
            if (weights_md_.format_kind == format_kind::any) {
                INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde),
                        src_md_);
            } else {
                format_tag_t weights_tag = get_tag(weights_md_);
                if (allow_all_tags && weights_tag == undef) {
                    INIT_MEM_BY_TAG(
                            utils::pick(ndims() - 2, ab, abc, abcd, abcde),
                            src_md_);
                } else {
                    INIT_MEM_BY_TAG(weights_tag, src_md_);
                }
                // transpose weights to improve efficiency of non-copy kernels
                if (src_md_.format_desc.blocking.strides[0] == 1)
                    transpose_md(src_md_);
            }
            return status::success;
        };

        auto set_default_weights = [&]() {
            // Mirror the src layout (src is guaranteed concrete at this
            // point: it was either user-provided or defaulted above).
            format_tag_t src_tag = get_tag(src_md_);
            if (allow_all_tags && src_tag == undef) {
                INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde),
                        weights_md_);
            } else {
                INIT_MEM_BY_TAG(src_tag, weights_md_);
            }
            /* with batch = 1, no transpose to use the faster gemv kernels */
            /* otherwise, we transpose the weights to improve efficiency of
             * no-copy kernels */
            if (MB() > 1 && transpose_leading_dim(OC(), IC_total()))
                transpose_md(weights_md_);
            return status::success;
        };

        // Order matters: src defaulting consults weights_md_ and weights
        // defaulting consults src_md_, so src must be resolved first.
        if (src_md_.format_kind == format_kind::any) CHECK(set_default_src());
        if (weights_md_.format_kind == format_kind::any)
            CHECK(set_default_weights());
        if (dst_md_.format_kind == format_kind::any)
            CHECK(memory_desc_init_by_tag(dst_md_, nc));
        if (bias_md_.format_kind == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md_, x));
        return status::success;
    }
};
211
// Base primitive descriptor for CPU inner product backward-data
// implementations; resolves `format_kind::any` descriptors the same way as
// the forward counterpart, with diff_src playing the role of src.
struct cpu_inner_product_bwd_data_pd_t : public inner_product_bwd_data_pd_t {
    using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t;

protected:
    // Resolves any `format_kind::any` among diff_src/weights/diff_dst.
    // When `allow_all_tags` is set, unrecognized layouts fall back to the
    // plain NC{D,H,W} layout instead of returning unimplemented.
    status_t set_default_params(bool allow_all_tags = false) {
        using namespace format_tag;

        auto set_default_diff_src = [&]() {
            // If weights are also `any`, pick the plain layout; otherwise
            // mirror the (user-provided) weights layout.
            if (weights_md_.format_kind == format_kind::any) {
                INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde),
                        diff_src_md_);
            } else {
                format_tag_t weights_tag = get_tag(weights_md_);
                if (allow_all_tags && weights_tag == undef) {
                    INIT_MEM_BY_TAG(
                            utils::pick(ndims() - 2, ab, abc, abcd, abcde),
                            diff_src_md_);
                } else {
                    INIT_MEM_BY_TAG(weights_tag, diff_src_md_);
                }
                // transpose to improve efficiency of non-copy kernels
                if (diff_src_md_.format_desc.blocking.strides[0] == 1)
                    transpose_md(diff_src_md_);
            }
            return status::success;
        };

        auto set_default_weights = [&]() {
            format_tag_t diff_src_tag = get_tag(diff_src_md_);
            if (allow_all_tags && diff_src_tag == undef) {
                INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde),
                        weights_md_);
            } else {
                INIT_MEM_BY_TAG(diff_src_tag, weights_md_);
            }
            /* with batch = 1, no transpose to use the faster gemv kernels */
            /* otherwise, we transpose the weights to improve efficiency of
             * no-copy kernels */
            // NOTE(review): the comment above describes skipping the
            // transpose for batch = 1, yet the code transposes exactly when
            // MB() == 1 — confirm this is intended for the backward-data
            // GeMM orientation.
            if (MB() == 1) transpose_md(weights_md_);

            return status::success;
        };

        // Order matters: diff_src defaulting consults weights_md_ and
        // weights defaulting consults diff_src_md_.
        if (diff_src_md_.format_kind == format_kind::any)
            CHECK(set_default_diff_src());
        if (weights_md_.format_kind == format_kind::any)
            CHECK(set_default_weights());
        if (diff_dst_md_.format_kind == format_kind::any)
            CHECK(memory_desc_init_by_tag(diff_dst_md_, nc));
        return status::success;
    }
};
263
// Base primitive descriptor for CPU inner product backward-weights
// implementations; resolves `format_kind::any` descriptors analogously to
// the forward counterpart, with diff_weights playing the role of weights.
struct cpu_inner_product_bwd_weights_pd_t
    : public inner_product_bwd_weights_pd_t {
    using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t;

protected:
    // Resolves any `format_kind::any` among src/diff_weights/diff_dst/
    // diff_bias. When `allow_all_tags` is set, unrecognized layouts fall
    // back to the plain NC{D,H,W} layout instead of returning unimplemented.
    status_t set_default_params(bool allow_all_tags = false) {
        using namespace format_tag;

        auto set_default_src = [&]() {
            // If diff_weights are also `any`, pick the plain layout;
            // otherwise mirror the (user-provided) diff_weights layout.
            if (diff_weights_md_.format_kind == format_kind::any) {
                INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde),
                        src_md_);
            } else {
                format_tag_t diff_weights_tag = get_tag(diff_weights_md_);
                if (allow_all_tags && diff_weights_tag == undef) {
                    INIT_MEM_BY_TAG(
                            utils::pick(ndims() - 2, ab, abc, abcd, abcde),
                            src_md_);
                } else {
                    INIT_MEM_BY_TAG(diff_weights_tag, src_md_);
                }
                // transpose to improve efficiency of non-copy kernels
                if (src_md_.format_desc.blocking.strides[0] == 1)
                    transpose_md(src_md_);
            }
            return status::success;
        };

        auto set_default_diff_weights = [&]() {
            format_tag_t src_tag = get_tag(src_md_);
            if (allow_all_tags && src_tag == undef) {
                INIT_MEM_BY_TAG(utils::pick(ndims() - 2, ab, abc, abcd, abcde),
                        diff_weights_md_);
            } else {
                INIT_MEM_BY_TAG(src_tag, diff_weights_md_);
            }
            // Here, we want diff_weights layout to match the fwd weights layout
            if (MB() > 1 && transpose_leading_dim(OC(), MB()))
                transpose_md(diff_weights_md_);
            return status::success;
        };

        // Order matters: src defaulting consults diff_weights_md_ and
        // diff_weights defaulting consults src_md_.
        if (src_md_.format_kind == format_kind::any) CHECK(set_default_src());
        if (diff_weights_md_.format_kind == format_kind::any)
            CHECK(set_default_diff_weights());
        if (diff_dst_md_.format_kind == format_kind::any)
            CHECK(memory_desc_init_by_tag(diff_dst_md_, nc));
        if (diff_bias_md_.format_kind == format_kind::any)
            CHECK(memory_desc_init_by_tag(diff_bias_md_, x));
        return status::success;
    }
};
315#undef INIT_MEM_BY_TAG
316
317} // namespace cpu
318} // namespace impl
319} // namespace dnnl
320
321#endif
322
323// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
324