1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/type_helpers.hpp"
20#include "common/utils.hpp"
21
22#include "cpu/cpu_primitive.hpp"
23#include "cpu/scale_utils.hpp"
24
25#include "cpu/x64/jit_avx512_core_amx_conv_utils.hpp"
26#include "cpu/x64/jit_avx512_core_amx_deconvolution.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32
33using namespace dnnl::impl::memory_tracking::names;
34
35#define wht_blk_off(d, g, ...) \
36 (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
37 : (d).blk_off(__VA_ARGS__))
38
39// NOTE: This primitive shares a kernel with bwd/d convolution. Hence, all
40// parameters stored in `pd()->jcp_` are in terms of bwd/d convolution.
41// This means that the following parameters have been exchanged:
42// 1. ic <-> oc
43// 2. ih <-> oh
44// 3. iw <-> ow
45// The same exchange applies to all derivative values in `pd()->jcp_`
// (e.g., ic_block <-> oc_block, etc.).
47
48void jit_avx512_core_amx_deconvolution_fwd_t::prepare_padded_bias(
49 const char *&bias, const memory_tracking::grantor_t &scratchpad) const {
50 auto &jcp = pd()->jcp_;
51 if (jcp.with_bias && jcp.ic != jcp.ic_without_padding) {
52 const size_t bia_dt_size = jcp.typesize_bia;
53 auto padded_bias = scratchpad.template get<char>(
54 memory_tracking::names::key_conv_padded_bias);
55 utils::array_copy(
56 padded_bias, bias, bia_dt_size * jcp.ic_without_padding);
57 utils::array_set(padded_bias + bia_dt_size * jcp.ic_without_padding,
58 0.f, bia_dt_size * (jcp.ic - jcp.ic_without_padding));
59 bias = padded_bias;
60 }
61}
62
63status_t jit_avx512_core_amx_deconvolution_fwd_t::execute_forward(
64 const exec_ctx_t &ctx) const {
65 auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
66 auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
67 auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
68 auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
69
70 const memory_desc_wrapper src_d(pd()->src_md());
71 const memory_desc_wrapper weights_d(pd()->weights_md(0));
72 const memory_desc_wrapper bias_d(pd()->weights_md(1));
73 const memory_desc_wrapper dst_d(pd()->dst_md());
74
75 prepare_padded_bias(bias, ctx.get_scratchpad_grantor());
76
77 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
78 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
79 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
80
81 const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
82 src_scales, wei_scales, dst_d.dims()[1], pd()->attr());
83
84 // The body of bwd/d convolution harness is called with:
85 // 1. src as input instead of diff_dst
86 // 2. dst as output instead of diff_src
87 amx_utils::execute_backward_convolution_body(ctx, pd()->jcp_, kernel_, src,
88 weights, bias, oscales, dst_scales, dst, src_d, weights_d, bias_d,
89 dst_d);
90
91 return status::success;
92}
93
94} // namespace x64
95} // namespace cpu
96} // namespace impl
97} // namespace dnnl
98
99// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
100