1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_IP_CONVOLUTION_HPP
18#define CPU_X64_IP_CONVOLUTION_HPP
19
20#include <string>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/primitive_desc_iterator.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/cpu_convolution_pd.hpp"
28#include "cpu/cpu_inner_product_pd.hpp"
29
30#include "cpu/x64/cpu_isa_traits.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35namespace x64 {
36
37namespace {
38
39status_t reshape_dst(memory_desc_t *o_md, const memory_desc_t *i_md) {
40 dims_t reduce {};
41 const dim_t ndims = 2; // dst is always nc for inner product
42 // conv to ip: remove spatial
43 for (int d = 0; d < ndims; ++d)
44 reduce[d] = i_md->dims[d];
45
46 return memory_desc_reshape(*o_md, *i_md, ndims, reduce);
47}
48
49status_t maybe_reshape_weights(memory_desc_t *o_md, const memory_desc_t *i_md,
50 bool with_groups, bool to_ip = false) {
51 dims_t reduce {};
52 const dim_t ndims = i_md->ndims + (to_ip ? -1 : +1) * with_groups;
53 if (to_ip) {
54 // conv to ip: maybe remove groups
55 for (int d = 0; d < ndims; ++d)
56 reduce[d] = i_md->dims[d + with_groups];
57 } else {
58 // ip to conv: maybe restore groups
59 if (with_groups) reduce[0] = 1;
60 for (int d = 0; d < ndims; ++d)
61 reduce[d + with_groups] = i_md->dims[d];
62 }
63
64 return memory_desc_reshape(*o_md, *i_md, ndims, reduce);
65}
66
67status_t check_conv_ip(convolution_pd_t *self) {
68 // Check if convolution is equivalent to inner product
69 const bool is_ip_applicable = true
70 // no dilations
71 && utils::everyone_is(0, self->KDD(), self->KDH(), self->KDW())
72 // no "left" padding
73 && utils::everyone_is(
74 0, self->padFront(), self->padT(), self->padL())
75 // no "right" padding
76 && utils::everyone_is(
77 0, self->padBack(), self->padB(), self->padR())
78 // no non-trivial groups or output spatial
79 && utils::everyone_is(
80 1, self->G(), self->OD(), self->OH(), self->OW())
81 // only unit stride
82 && utils::everyone_is(1, self->KSD(), self->KSH(), self->KSW());
83 if (!is_ip_applicable) return status::unimplemented;
84
85 // Simple heuristic to only target arches and shapes that benefit.
86 // TODO: Extend to other arches and shapes as performance allows.
87 const dim_t ks = self->KD() * self->KH() * self->KW();
88 const dim_t ks_threshold = 27; // empirical
89 const bool is_performant
90 = 1 < self->MB() && ks > ks_threshold && mayiuse(avx512_core);
91 if (!is_performant) return status::unimplemented;
92
93 return status::success;
94}
95
96status_t check_tag(memory_desc_t &md, const format_tag_t tag) {
97 const memory_desc_wrapper mdw(&md);
98 if (mdw.matches_one_of_tag(tag) == format_tag::undef)
99 return status::unimplemented;
100 return status::success;
101}
102
// For each activation/weight/bias descriptor: if the user left the format
// as `any` AND the nspc layout is expected to be profitable, set the plain
// channels-last tag; otherwise require that the descriptor already matches
// it. Finally propagates default formats to the attributes.
status_t set_and_or_check_formats(const convolution_desc_t &desc,
        memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr) {
    using namespace format_tag;
    // Channels-last activation tag for 1d/2d/3d spatial.
    // NOTE(review): assumes 3 <= src_md.ndims <= 5 — presumably guaranteed
    // by the convolution descriptor; confirm against callers.
    auto atag = utils::pick(src_md.ndims - 3, nwc, nhwc, ndhwc);
    const bool is_fwd = utils::one_of(desc.prop_kind,
            prop_kind::forward_training, prop_kind::forward_inference);
    // Bias is present for fwd and backward-by-weights, absent for
    // backward-by-data.
    const bool with_bias = desc.prop_kind != prop_kind::backward_data;

    // Check that nspc is the default layout for convolutions,
    // or that expected performance gain outweighs potential
    // cost of extra reorders.
    // Currently this means:
    // - int8 with any forward prop_kind on any isa
    // - fp32/bf16 with any prop_kind on avx512_core and higher
    // - f16
    const auto wei_dt = weights_md.data_type;
    const bool is_set_allowed = false
            || (utils::one_of(wei_dt, data_type::f32, data_type::bf16)
                    && mayiuse(avx512_core))
            || (is_fwd && wei_dt == data_type::s8)
            || (wei_dt == data_type::f16 && mayiuse(avx512_core_fp16));

    // NOTE: Only plain layouts should be supported since the dims of
    // dst_md_ must be reshaped from {N, C, H, W} to {N, C}. If the
    // conv layout is blocked by channel, then the ip layout will also
    // be blocked by channel (eg nChw16c -> nC16c). This can lead to
    // deployment of reference ip as well as strange weights layouts.
    if (is_set_allowed && src_md.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(src_md, atag));
    else
        CHECK(check_tag(src_md, atag));
    if (is_set_allowed && dst_md.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(dst_md, atag));
    else
        CHECK(check_tag(dst_md, atag));
    if (with_bias && bias_md.format_kind != format_kind::undef) {
        // Bias is always a plain 1d vector.
        auto btag = x;
        if (bias_md.format_kind == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, btag));
        else
            CHECK(check_tag(bias_md, btag));
    }
    return attr.set_default_formats(&dst_md);
}
148
149} // namespace
150
151struct ip_convolution_fwd_t : public primitive_t {
152 struct pd_t : public cpu_convolution_fwd_pd_t {
153 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
154 const convolution_fwd_pd_t *hint_fwd_pd)
155 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
156
157 pd_t(const pd_t &other)
158 : cpu_convolution_fwd_pd_t(other)
159 , ip_pd_(other.ip_pd_->clone())
160 , name_(other.name_) {}
161
162 ~pd_t() = default;
163
164 DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_fwd_t);
165
166 status_t init_ip(engine_t *engine) {
167 inner_product_desc_t ipd;
168 CHECK(ip_desc_create(&ipd));
169 primitive_desc_iterator_t it(
170 engine, (op_desc_t *)&ipd, attr(), nullptr);
171 if (!it.is_initialized()) return status::out_of_memory;
172
173 while (++it != it.end()) {
174 ip_pd_ = *it;
175 const bool ok = ip_pd_->weights_md()->extra.flags == 0;
176 if (ok) return status::success;
177 }
178 return status::unimplemented;
179 }
180
181 status_t init(engine_t *engine) {
182 using namespace format_tag;
183 using smask_t = primitive_attr_t::skip_mask_t;
184
185 const bool ok = is_fwd()
186 && set_default_alg_kind(alg_kind::convolution_direct)
187 && attr()->has_default_values(
188 smask_t::scales_runtime | smask_t::post_ops);
189 if (!ok) return status::unimplemented;
190
191 CHECK(check_conv_ip(this));
192
193 CHECK(set_and_or_check_formats(
194 *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_));
195
196 CHECK(init_ip(engine));
197
198 if (weights_md_.format_kind == format_kind::any)
199 CHECK(maybe_reshape_weights(
200 &weights_md_, ip_pd_->weights_md(), with_groups()));
201
202 init_name();
203 init_scratchpad();
204 return status::success;
205 }
206
207 std::shared_ptr<primitive_desc_t> ip_pd_;
208
209 private:
210 std::string name_ = "ip:";
211
212 void init_name() {
213 const std::string ips(ip_pd_->name());
214 const std::string prefix = "x64:";
215 const size_t pos = ips.find(prefix);
216 name_.append(ips, pos + prefix.length(), std::string::npos);
217 }
218
219 void init_scratchpad() {
220 using namespace memory_tracking::names;
221 auto scratchpad = scratchpad_registry().registrar();
222 scratchpad.book(key_nested, ip_pd_->scratchpad_registry());
223 }
224
225 status_t ip_desc_create(inner_product_desc_t *ipd) {
226 const bool to_ip = true;
227
228 // reinterpret dst without spatial
229 memory_desc_t ip_dst_d;
230 CHECK(reshape_dst(&ip_dst_d, &dst_md_));
231
232 // reinterpret weights without groups
233 memory_desc_t ip_weights_d;
234 CHECK(maybe_reshape_weights(
235 &ip_weights_d, &weights_md_, with_groups(), to_ip));
236
237 return ip_desc_init(ipd, desc()->prop_kind, &src_md_, &ip_weights_d,
238 &bias_md_, &ip_dst_d);
239 }
240 };
241
242 ip_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
243
244 status_t init(engine_t *engine) override {
245 CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine));
246 return status::success;
247 }
248
249 status_t execute(const exec_ctx_t &ctx) const override;
250
251private:
252 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
253 std::shared_ptr<primitive_t> ip_p_;
254};
255
256struct ip_convolution_bwd_data_t : public primitive_t {
257 struct pd_t : public cpu_convolution_bwd_data_pd_t {
258 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
259 const convolution_fwd_pd_t *hint_fwd_pd)
260 : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}
261
262 pd_t(const pd_t &other)
263 : cpu_convolution_bwd_data_pd_t(other)
264 , ip_pd_(other.ip_pd_->clone()) {}
265
266 ~pd_t() = default;
267
268 DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_bwd_data_t);
269
270 status_t init_ip(engine_t *engine) {
271 inner_product_desc_t ipd;
272 CHECK(ip_desc_create(&ipd));
273 primitive_desc_iterator_t it(
274 engine, (op_desc_t *)&ipd, attr(), nullptr);
275 if (!it.is_initialized()) return status::out_of_memory;
276 while (++it != it.end()) {
277 ip_pd_ = *it;
278 const bool ok = ip_pd_->weights_md()->extra.flags == 0;
279 if (ok) return status::success;
280 }
281 return status::unimplemented;
282 }
283
284 status_t init(engine_t *engine) {
285 using namespace format_tag;
286
287 const bool ok = desc()->prop_kind == prop_kind::backward_data
288 && set_default_alg_kind(alg_kind::convolution_direct)
289 && attr()->has_default_values();
290 if (!ok) return status::unimplemented;
291
292 CHECK(check_conv_ip(this));
293
294 CHECK(set_and_or_check_formats(*desc(), diff_src_md_, weights_md_,
295 diff_dst_md_, bias_md_, attr_));
296
297 CHECK(init_ip(engine));
298
299 if (weights_md_.format_kind == format_kind::any)
300 CHECK(maybe_reshape_weights(
301 &weights_md_, ip_pd_->weights_md(), with_groups()));
302
303 init_name();
304 init_scratchpad();
305 return status::success;
306 }
307
308 std::shared_ptr<primitive_desc_t> ip_pd_;
309
310 private:
311 std::string name_ = "ip:";
312
313 void init_name() { name_.append(ip_pd_->name()); }
314
315 void init_scratchpad() {
316 using namespace memory_tracking::names;
317 auto scratchpad = scratchpad_registry().registrar();
318 scratchpad.book(key_nested, ip_pd_->scratchpad_registry());
319 }
320
321 status_t ip_desc_create(inner_product_desc_t *ipd) {
322 const bool to_ip = true;
323
324 // reinterpret dst without spatial
325 memory_desc_t ip_diff_dst_d;
326 CHECK(reshape_dst(&ip_diff_dst_d, &diff_dst_md_));
327
328 // reinterpret weights without groups
329 memory_desc_t ip_weights_d;
330 CHECK(maybe_reshape_weights(
331 &ip_weights_d, &weights_md_, with_groups(), to_ip));
332
333 return ip_desc_init(ipd, desc()->prop_kind, &diff_src_md_,
334 &ip_weights_d, nullptr, &ip_diff_dst_d);
335 }
336 };
337
338 ip_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
339
340 status_t init(engine_t *engine) override {
341 CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine));
342 return status::success;
343 }
344
345 status_t execute(const exec_ctx_t &ctx) const override;
346
347private:
348 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
349 std::shared_ptr<primitive_t> ip_p_;
350};
351
352struct ip_convolution_bwd_weights_t : public primitive_t {
353 struct pd_t : public cpu_convolution_bwd_weights_pd_t {
354 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
355 const convolution_fwd_pd_t *hint_fwd_pd)
356 : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}
357
358 pd_t(const pd_t &other)
359 : cpu_convolution_bwd_weights_pd_t(other)
360 , ip_pd_(other.ip_pd_->clone()) {}
361
362 ~pd_t() = default;
363
364 DECLARE_COMMON_PD_T(name_.c_str(), ip_convolution_bwd_weights_t);
365
366 status_t init_ip(engine_t *engine) {
367 inner_product_desc_t ipd;
368 CHECK(ip_desc_create(&ipd));
369 primitive_desc_iterator_t it(
370 engine, (op_desc_t *)&ipd, attr(), nullptr);
371 if (!it.is_initialized()) return status::out_of_memory;
372
373 while (++it != it.end()) {
374 ip_pd_ = *it;
375 const bool ok = ip_pd_->weights_md()->extra.flags == 0;
376 if (ok) return status::success;
377 }
378 return status::unimplemented;
379 }
380
381 status_t init(engine_t *engine) {
382 using namespace format_tag;
383
384 const bool ok = desc()->prop_kind == prop_kind::backward_weights
385 && set_default_alg_kind(alg_kind::convolution_direct)
386 && attr()->has_default_values();
387 if (!ok) return status::unimplemented;
388
389 CHECK(check_conv_ip(this));
390
391 CHECK(set_and_or_check_formats(*desc(), src_md_, diff_weights_md_,
392 diff_dst_md_, diff_bias_md_, attr_));
393
394 CHECK(init_ip(engine));
395
396 if (diff_weights_md_.format_kind == format_kind::any)
397 CHECK(maybe_reshape_weights(&diff_weights_md_,
398 ip_pd_->diff_weights_md(), with_groups()));
399
400 init_name();
401 init_scratchpad();
402 return status::success;
403 }
404
405 std::shared_ptr<primitive_desc_t> ip_pd_;
406
407 private:
408 std::string name_ = "ip:";
409
410 void init_name() { name_.append(ip_pd_->name()); }
411
412 void init_scratchpad() {
413 using namespace memory_tracking::names;
414 auto scratchpad = scratchpad_registry().registrar();
415 scratchpad.book(key_nested, ip_pd_->scratchpad_registry());
416 }
417
418 status_t ip_desc_create(inner_product_desc_t *ipd) {
419 const bool to_ip = true;
420
421 // reinterpret dst without spatial
422 memory_desc_t ip_diff_dst_d;
423 CHECK(reshape_dst(&ip_diff_dst_d, &diff_dst_md_));
424
425 // reinterpret weights without groups
426 memory_desc_t ip_diff_weights_d;
427 CHECK(maybe_reshape_weights(&ip_diff_weights_d, &diff_weights_md_,
428 with_groups(), to_ip));
429
430 return ip_desc_init(ipd, desc()->prop_kind, &src_md_,
431 &ip_diff_weights_d, &diff_bias_md_, &ip_diff_dst_d);
432 }
433 };
434
435 ip_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}
436
437 status_t init(engine_t *engine) override {
438 CHECK(pd()->ip_pd_->create_primitive(ip_p_, engine));
439 return status::success;
440 }
441
442 status_t execute(const exec_ctx_t &ctx) const override;
443
444private:
445 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
446 std::shared_ptr<primitive_t> ip_p_;
447};
448
449} // namespace x64
450} // namespace cpu
451} // namespace impl
452} // namespace dnnl
453
454#endif
455
456// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
457