1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_IP_CONVOLUTION_HPP |
18 | #define CPU_X64_IP_CONVOLUTION_HPP |
19 | |
20 | #include <string> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/primitive_desc_iterator.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | #include "cpu/cpu_convolution_pd.hpp" |
28 | #include "cpu/cpu_inner_product_pd.hpp" |
29 | |
30 | #include "cpu/x64/cpu_isa_traits.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace x64 { |
36 | |
37 | namespace { |
38 | |
39 | status_t reshape_dst(memory_desc_t *o_md, const memory_desc_t *i_md) { |
40 | dims_t reduce {}; |
41 | const dim_t ndims = 2; // dst is always nc for inner product |
42 | // conv to ip: remove spatial |
43 | for (int d = 0; d < ndims; ++d) |
44 | reduce[d] = i_md->dims[d]; |
45 | |
46 | return memory_desc_reshape(*o_md, *i_md, ndims, reduce); |
47 | } |
48 | |
49 | status_t maybe_reshape_weights(memory_desc_t *o_md, const memory_desc_t *i_md, |
50 | bool with_groups, bool to_ip = false) { |
51 | dims_t reduce {}; |
52 | const dim_t ndims = i_md->ndims + (to_ip ? -1 : +1) * with_groups; |
53 | if (to_ip) { |
54 | // conv to ip: maybe remove groups |
55 | for (int d = 0; d < ndims; ++d) |
56 | reduce[d] = i_md->dims[d + with_groups]; |
57 | } else { |
58 | // ip to conv: maybe restore groups |
59 | if (with_groups) reduce[0] = 1; |
60 | for (int d = 0; d < ndims; ++d) |
61 | reduce[d + with_groups] = i_md->dims[d]; |
62 | } |
63 | |
64 | return memory_desc_reshape(*o_md, *i_md, ndims, reduce); |
65 | } |
66 | |
67 | status_t check_conv_ip(convolution_pd_t *self) { |
68 | // Check if convolution is equivalent to inner product |
69 | const bool is_ip_applicable = true |
70 | // no dilations |
71 | && utils::everyone_is(0, self->KDD(), self->KDH(), self->KDW()) |
72 | // no "left" padding |
73 | && utils::everyone_is( |
74 | 0, self->padFront(), self->padT(), self->padL()) |
75 | // no "right" padding |
76 | && utils::everyone_is( |
77 | 0, self->padBack(), self->padB(), self->padR()) |
78 | // no non-trivial groups or output spatial |
79 | && utils::everyone_is( |
80 | 1, self->G(), self->OD(), self->OH(), self->OW()) |
81 | // only unit stride |
82 | && utils::everyone_is(1, self->KSD(), self->KSH(), self->KSW()); |
83 | if (!is_ip_applicable) return status::unimplemented; |
84 | |
85 | // Simple heuristic to only target arches and shapes that benefit. |
86 | // TODO: Extend to other arches and shapes as performance allows. |
87 | const dim_t ks = self->KD() * self->KH() * self->KW(); |
88 | const dim_t ks_threshold = 27; // empirical |
89 | const bool is_performant |
90 | = 1 < self->MB() && ks > ks_threshold && mayiuse(avx512_core); |
91 | if (!is_performant) return status::unimplemented; |
92 | |
93 | return status::success; |
94 | } |
95 | |
96 | status_t check_tag(memory_desc_t &md, const format_tag_t tag) { |
97 | const memory_desc_wrapper mdw(&md); |
98 | if (mdw.matches_one_of_tag(tag) == format_tag::undef) |
99 | return status::unimplemented; |
100 | return status::success; |
101 | } |
102 | |
103 | status_t set_and_or_check_formats(const convolution_desc_t &desc, |
104 | memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md, |
105 | memory_desc_t &bias_md, primitive_attr_t &attr) { |
106 | using namespace format_tag; |
107 | auto atag = utils::pick(src_md.ndims - 3, nwc, nhwc, ndhwc); |
108 | const bool is_fwd = utils::one_of(desc.prop_kind, |
109 | prop_kind::forward_training, prop_kind::forward_inference); |
110 | const bool with_bias = desc.prop_kind != prop_kind::backward_data; |
111 | |
112 | // Check that nspc is the default layout for convolutions, |
113 | // or that expected performance gain outweights potential |
114 | // cost of extra reorders. |
115 | // Currently this means: |
116 | // - int8 with any forward prop_kind on any isa |
117 | // - fp32/bf16 with any prop_kind on avx512_core and higher |
118 | // - f16 |
119 | const auto wei_dt = weights_md.data_type; |
120 | const bool is_set_allowed = false |
121 | || (utils::one_of(wei_dt, data_type::f32, data_type::bf16) |
122 | && mayiuse(avx512_core)) |
123 | || (is_fwd && wei_dt == data_type::s8) |
124 | || (wei_dt == data_type::f16 && mayiuse(avx512_core_fp16)); |
125 | |
126 | // NOTE: Only plain layouts should be supported since the dims of |
127 | // dst_md_ must be reshaped from {N, C, H, W} to {N, C}. If the |
128 | // conv layout is blocked by channel, then the ip layout will also |
129 | // be blocked by channel (eg nChw16c -> nC16c). This can lead to |
130 | // deployment of reference ip as well as strange weights layouts. |
131 | if (is_set_allowed && src_md.format_kind == format_kind::any) |
132 | CHECK(memory_desc_init_by_tag(src_md, atag)); |
133 | else |
134 | CHECK(check_tag(src_md, atag)); |
135 | if (is_set_allowed && dst_md.format_kind == format_kind::any) |
136 | CHECK(memory_desc_init_by_tag(dst_md, atag)); |
137 | else |
138 | CHECK(check_tag(dst_md, atag)); |
139 | if (with_bias && bias_md.format_kind != format_kind::undef) { |
140 | auto btag = x; |
141 | if (bias_md.format_kind == format_kind::any) |
142 | CHECK(memory_desc_init_by_tag(bias_md, btag)); |
143 | else |
144 | CHECK(check_tag(bias_md, btag)); |
145 | } |
146 | return attr.set_default_formats(&dst_md); |
147 | } |
148 | |
149 | } // namespace |
150 | |
151 | struct ip_convolution_fwd_t : public primitive_t { |
152 | struct pd_t : public cpu_convolution_fwd_pd_t { |
153 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
154 | const convolution_fwd_pd_t *hint_fwd_pd) |
155 | : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {} |
156 | |
157 | pd_t(const pd_t &other) |
158 | : cpu_convolution_fwd_pd_t(other) |
159 | , ip_pd_(other.ip_pd_->clone()) |
160 | , name_(other.name_) {} |
161 | |
162 | ~pd_t() = default; |
163 | |
164 | DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_fwd_t); |
165 | |
166 | status_t init_ip(engine_t *engine) { |
167 | inner_product_desc_t ipd; |
168 | CHECK(ip_desc_create(&ipd)); |
169 | primitive_desc_iterator_t it( |
170 | engine, (op_desc_t *)&ipd, attr(), nullptr); |
171 | if (!it.is_initialized()) return status::out_of_memory; |
172 | |
173 | while (++it != it.end()) { |
174 | ip_pd_ = *it; |
175 | const bool ok = ip_pd_->weights_md()->extra.flags == 0; |
176 | if (ok) return status::success; |
177 | } |
178 | return status::unimplemented; |
179 | } |
180 | |
181 | status_t init(engine_t *engine) { |
182 | using namespace format_tag; |
183 | using smask_t = primitive_attr_t::skip_mask_t; |
184 | |
185 | const bool ok = is_fwd() |
186 | && set_default_alg_kind(alg_kind::convolution_direct) |
187 | && attr()->has_default_values( |
188 | smask_t::scales_runtime | smask_t::post_ops); |
189 | if (!ok) return status::unimplemented; |
190 | |
191 | CHECK(check_conv_ip(this)); |
192 | |
193 | CHECK(set_and_or_check_formats( |
194 | *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_)); |
195 | |
196 | CHECK(init_ip(engine)); |
197 | |
198 | if (weights_md_.format_kind == format_kind::any) |
199 | CHECK(maybe_reshape_weights( |
200 | &weights_md_, ip_pd_->weights_md(), with_groups())); |
201 | |
202 | init_name(); |
203 | init_scratchpad(); |
204 | return status::success; |
205 | } |
206 | |
207 | std::shared_ptr<primitive_desc_t> ip_pd_; |
208 | |
209 | private: |
210 | std::string name_ = "ip:" ; |
211 | |
212 | void init_name() { |
213 | const std::string ips(ip_pd_->name()); |
214 | const std::string prefix = "x64:" ; |
215 | const size_t pos = ips.find(prefix); |
216 | name_.append(ips, pos + prefix.length(), std::string::npos); |
217 | } |
218 | |
219 | void init_scratchpad() { |
220 | using namespace memory_tracking::names; |
221 | auto scratchpad = scratchpad_registry().registrar(); |
222 | scratchpad.book(key_nested, ip_pd_->scratchpad_registry()); |
223 | } |
224 | |
225 | status_t ip_desc_create(inner_product_desc_t *ipd) { |
226 | const bool to_ip = true; |
227 | |
228 | // reinterpret dst without spatial |
229 | memory_desc_t ip_dst_d; |
230 | CHECK(reshape_dst(&ip_dst_d, &dst_md_)); |
231 | |
232 | // reinterpret weights without groups |
233 | memory_desc_t ip_weights_d; |
234 | CHECK(maybe_reshape_weights( |
235 | &ip_weights_d, &weights_md_, with_groups(), to_ip)); |
236 | |
237 | return ip_desc_init(ipd, desc()->prop_kind, &src_md_, &ip_weights_d, |
238 | &bias_md_, &ip_dst_d); |
239 | } |
240 | }; |
241 | |
242 | ip_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
243 | |
244 | status_t init(engine_t *engine) override { |
245 | CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine)); |
246 | return status::success; |
247 | } |
248 | |
249 | status_t execute(const exec_ctx_t &ctx) const override; |
250 | |
251 | private: |
252 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
253 | std::shared_ptr<primitive_t> ip_p_; |
254 | }; |
255 | |
256 | struct ip_convolution_bwd_data_t : public primitive_t { |
257 | struct pd_t : public cpu_convolution_bwd_data_pd_t { |
258 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
259 | const convolution_fwd_pd_t *hint_fwd_pd) |
260 | : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {} |
261 | |
262 | pd_t(const pd_t &other) |
263 | : cpu_convolution_bwd_data_pd_t(other) |
264 | , ip_pd_(other.ip_pd_->clone()) {} |
265 | |
266 | ~pd_t() = default; |
267 | |
268 | DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_bwd_data_t); |
269 | |
270 | status_t init_ip(engine_t *engine) { |
271 | inner_product_desc_t ipd; |
272 | CHECK(ip_desc_create(&ipd)); |
273 | primitive_desc_iterator_t it( |
274 | engine, (op_desc_t *)&ipd, attr(), nullptr); |
275 | if (!it.is_initialized()) return status::out_of_memory; |
276 | while (++it != it.end()) { |
277 | ip_pd_ = *it; |
278 | const bool ok = ip_pd_->weights_md()->extra.flags == 0; |
279 | if (ok) return status::success; |
280 | } |
281 | return status::unimplemented; |
282 | } |
283 | |
284 | status_t init(engine_t *engine) { |
285 | using namespace format_tag; |
286 | |
287 | const bool ok = desc()->prop_kind == prop_kind::backward_data |
288 | && set_default_alg_kind(alg_kind::convolution_direct) |
289 | && attr()->has_default_values(); |
290 | if (!ok) return status::unimplemented; |
291 | |
292 | CHECK(check_conv_ip(this)); |
293 | |
294 | CHECK(set_and_or_check_formats(*desc(), diff_src_md_, weights_md_, |
295 | diff_dst_md_, bias_md_, attr_)); |
296 | |
297 | CHECK(init_ip(engine)); |
298 | |
299 | if (weights_md_.format_kind == format_kind::any) |
300 | CHECK(maybe_reshape_weights( |
301 | &weights_md_, ip_pd_->weights_md(), with_groups())); |
302 | |
303 | init_name(); |
304 | init_scratchpad(); |
305 | return status::success; |
306 | } |
307 | |
308 | std::shared_ptr<primitive_desc_t> ip_pd_; |
309 | |
310 | private: |
311 | std::string name_ = "ip:" ; |
312 | |
313 | void init_name() { name_.append(ip_pd_->name()); } |
314 | |
315 | void init_scratchpad() { |
316 | using namespace memory_tracking::names; |
317 | auto scratchpad = scratchpad_registry().registrar(); |
318 | scratchpad.book(key_nested, ip_pd_->scratchpad_registry()); |
319 | } |
320 | |
321 | status_t ip_desc_create(inner_product_desc_t *ipd) { |
322 | const bool to_ip = true; |
323 | |
324 | // reinterpret dst without spatial |
325 | memory_desc_t ip_diff_dst_d; |
326 | CHECK(reshape_dst(&ip_diff_dst_d, &diff_dst_md_)); |
327 | |
328 | // reinterpret weights without groups |
329 | memory_desc_t ip_weights_d; |
330 | CHECK(maybe_reshape_weights( |
331 | &ip_weights_d, &weights_md_, with_groups(), to_ip)); |
332 | |
333 | return ip_desc_init(ipd, desc()->prop_kind, &diff_src_md_, |
334 | &ip_weights_d, nullptr, &ip_diff_dst_d); |
335 | } |
336 | }; |
337 | |
338 | ip_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} |
339 | |
340 | status_t init(engine_t *engine) override { |
341 | CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine)); |
342 | return status::success; |
343 | } |
344 | |
345 | status_t execute(const exec_ctx_t &ctx) const override; |
346 | |
347 | private: |
348 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
349 | std::shared_ptr<primitive_t> ip_p_; |
350 | }; |
351 | |
352 | struct ip_convolution_bwd_weights_t : public primitive_t { |
353 | struct pd_t : public cpu_convolution_bwd_weights_pd_t { |
354 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
355 | const convolution_fwd_pd_t *hint_fwd_pd) |
356 | : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {} |
357 | |
358 | pd_t(const pd_t &other) |
359 | : cpu_convolution_bwd_weights_pd_t(other) |
360 | , ip_pd_(other.ip_pd_->clone()) {} |
361 | |
362 | ~pd_t() = default; |
363 | |
364 | DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_bwd_weights_t); |
365 | |
366 | status_t init_ip(engine_t *engine) { |
367 | inner_product_desc_t ipd; |
368 | CHECK(ip_desc_create(&ipd)); |
369 | primitive_desc_iterator_t it( |
370 | engine, (op_desc_t *)&ipd, attr(), nullptr); |
371 | if (!it.is_initialized()) return status::out_of_memory; |
372 | |
373 | while (++it != it.end()) { |
374 | ip_pd_ = *it; |
375 | const bool ok = ip_pd_->weights_md()->extra.flags == 0; |
376 | if (ok) return status::success; |
377 | } |
378 | return status::unimplemented; |
379 | } |
380 | |
381 | status_t init(engine_t *engine) { |
382 | using namespace format_tag; |
383 | |
384 | const bool ok = desc()->prop_kind == prop_kind::backward_weights |
385 | && set_default_alg_kind(alg_kind::convolution_direct) |
386 | && attr()->has_default_values(); |
387 | if (!ok) return status::unimplemented; |
388 | |
389 | CHECK(check_conv_ip(this)); |
390 | |
391 | CHECK(set_and_or_check_formats(*desc(), src_md_, diff_weights_md_, |
392 | diff_dst_md_, diff_bias_md_, attr_)); |
393 | |
394 | CHECK(init_ip(engine)); |
395 | |
396 | if (diff_weights_md_.format_kind == format_kind::any) |
397 | CHECK(maybe_reshape_weights(&diff_weights_md_, |
398 | ip_pd_->diff_weights_md(), with_groups())); |
399 | |
400 | init_name(); |
401 | init_scratchpad(); |
402 | return status::success; |
403 | } |
404 | |
405 | std::shared_ptr<primitive_desc_t> ip_pd_; |
406 | |
407 | private: |
408 | std::string name_ = "ip:" ; |
409 | |
410 | void init_name() { name_.append(ip_pd_->name()); } |
411 | |
412 | void init_scratchpad() { |
413 | using namespace memory_tracking::names; |
414 | auto scratchpad = scratchpad_registry().registrar(); |
415 | scratchpad.book(key_nested, ip_pd_->scratchpad_registry()); |
416 | } |
417 | |
418 | status_t ip_desc_create(inner_product_desc_t *ipd) { |
419 | const bool to_ip = true; |
420 | |
421 | // reinterpret dst without spatial |
422 | memory_desc_t ip_diff_dst_d; |
423 | CHECK(reshape_dst(&ip_diff_dst_d, &diff_dst_md_)); |
424 | |
425 | // reinterpret weights without groups |
426 | memory_desc_t ip_diff_weights_d; |
427 | CHECK(maybe_reshape_weights(&ip_diff_weights_d, &diff_weights_md_, |
428 | with_groups(), to_ip)); |
429 | |
430 | return ip_desc_init(ipd, desc()->prop_kind, &src_md_, |
431 | &ip_diff_weights_d, &diff_bias_md_, &ip_diff_dst_d); |
432 | } |
433 | }; |
434 | |
435 | ip_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {} |
436 | |
437 | status_t init(engine_t *engine) override { |
438 | CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine)); |
439 | return status::success; |
440 | } |
441 | |
442 | status_t execute(const exec_ctx_t &ctx) const override; |
443 | |
444 | private: |
445 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
446 | std::shared_ptr<primitive_t> ip_p_; |
447 | }; |
448 | |
449 | } // namespace x64 |
450 | } // namespace cpu |
451 | } // namespace impl |
452 | } // namespace dnnl |
453 | |
454 | #endif |
455 | |
456 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
457 | |