/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/compiler_workarounds.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_brgemm_conv_bwd.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

namespace {
status_t weights_axes_permutation(
        memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) {
    int perm[DNNL_MAX_NDIMS] {}; // bwd conv to fwd conv weight permutation
    for (int d = 0; d < DNNL_MAX_NDIMS; ++d)
        perm[d] = d;
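    // Swapping the first two non-group axes turns, e.g., a grouped 2D weights
    // layout (g, oc, ic, kh, kw) into (g, ic, oc, kh, kw), so the bwd weights
    // can be consumed by a fwd convolution.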
    nstl::swap(perm[0 + with_groups], perm[1 + with_groups]);

    return memory_desc_permute_axes(*o_md, *i_md, perm);
}

status_t fwd_conv_desc_create(
        convolution_desc_t *fwd_conv_d, const convolution_desc_t *bwd_conv_d) {
    // create a new weights descriptor with OC and IC transposed;
    // spatial inversion is handled by inverting indices on-the-fly
    memory_desc_t fwd_weights_md;
    const memory_desc_t &bwd_weights_md = bwd_conv_d->weights_desc;
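    // grouped weights carry an extra leading groups dim, hence the ndims + 1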
    const bool with_groups
            = bwd_weights_md.ndims == bwd_conv_d->diff_src_desc.ndims + 1;
    CHECK(weights_axes_permutation(
            &fwd_weights_md, &bwd_weights_md, with_groups));

    // create a fwd convolution descriptor with padding adjusted
    // to the perspective of backward propagation, namely:
    // - left padding replaced by left overflow
    // - right padding replaced by right overflow
    const int ndims_spatial = bwd_conv_d->diff_src_desc.ndims - 2;
    dims_t overflow_l;
    dims_t overflow_r;
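    // ks accumulates the total number of kernel spatial elements; ks == 1
    // means a 1x1 kernel, for which no spatial inversion of weights is needed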
    dim_t ks = 1;
    for (int i = 0; i < ndims_spatial; i++) {
        // only unit strides are allowed for bwd-to-fwd conversion
        if (bwd_conv_d->strides[i] != 1) return status::unimplemented;
        const dim_t K
                = bwd_weights_md.dims[bwd_weights_md.ndims - ndims_spatial + i];
        ks *= K;
        const dim_t D = bwd_conv_d->dilates[i];
        const dim_t PL = bwd_conv_d->padding[0][i]; // left padding
        const dim_t PR = bwd_conv_d->padding[1][i]; // right padding
        constexpr dim_t S = 1;
        // the following relations hold for unit stride only
        overflow_l[i] = ((K - 1) * (D + 1) - PL) / S;
        overflow_r[i] = ((K - 1) * (D + 1) - PR) / S;
    }
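    // Worked example (illustration only): a 3x3 kernel (K = 3) with no
    // dilation (D = 0, stored as dilation - 1) and left padding PL = 1 gives
    // overflow_l = (3 - 1) * (0 + 1) - 1 = 1, i.e. the fwd conv that computes
    // diff_src needs one extra element of left padding over diff_dst.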

    CHECK(conv_desc_init(fwd_conv_d, prop_kind::forward_training,
            alg_kind::convolution_direct, &bwd_conv_d->diff_dst_desc,
            &fwd_weights_md, &bwd_conv_d->bias_desc, &bwd_conv_d->diff_src_desc,
            bwd_conv_d->strides, bwd_conv_d->dilates, overflow_l, overflow_r));

    // HACK: Set diff_src_desc and diff_dst_desc as a signal to the primitive
    //       descriptor cache that we are using the bwd-via-fwd version of
    //       fwd conv and thus need a separate cache entry. Only needed for
    //       non-1x1 convs due to spatial inversion of weights. This assumes
    //       that external users only use the API to create conv descs, and
    //       relies on common/convolution.cpp only setting the expected mem
    //       descs.
    // TODO: Pass this information via attributes or integrate the bwd-via-fwd
    //       method directly into fwd conv implementations.
    const bool with_spatial_inversion = ks > 1;
    if (with_spatial_inversion) {
        fwd_conv_d->diff_src_desc = fwd_conv_d->src_desc;
        fwd_conv_d->diff_dst_desc = fwd_conv_d->dst_desc;
    }
    return status::success;
}
} // namespace

template <cpu_isa_t isa>
status_t brgemm_convolution_bwd_t<isa>::pd_t::init(engine_t *engine) {
    using namespace data_type;
    using namespace utils;

    const bool ok = is_bwd_d()
            && set_default_alg_kind(alg_kind::convolution_direct)
            && attr()->has_default_values() && !has_zero_dim_memory();
    if (!ok) return status::unimplemented;

    convolution_desc_t fwd_conv_d = convolution_desc_t();
    CHECK(fwd_conv_desc_create(&fwd_conv_d, desc()));

    primitive_desc_t *pd;
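    // Prefer the specialized 1x1 fwd implementation when it applies; otherwise
    // fall back to the generic fwd conv that inverts the weights' spatial
    // indices on the fly.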
    do {
        // try creating fwd 1x1 conv prim desc
        using fwd_1x1_conv_pd_t =
                typename brgemm_1x1_convolution_fwd_t<isa>::pd_t;
        status_t s = primitive_desc_t::create<fwd_1x1_conv_pd_t>(&pd,
                reinterpret_cast<const op_desc_t *>(&fwd_conv_d), attr(),
                engine, nullptr);
        if (s == status::success) break;
        // try creating fwd conv prim desc
        constexpr bool use_inversion = true; // invert weights' spatial indices
        using fwd_conv_pd_t =
                typename brgemm_convolution_fwd_t<isa, use_inversion>::pd_t;
        CHECK(primitive_desc_t::create<fwd_conv_pd_t>(&pd,
                reinterpret_cast<const op_desc_t *>(&fwd_conv_d), attr(),
                engine, nullptr));
    } while (false);
    fwd_pd_.reset(pd);

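    // Inherit any still-undefined memory formats from the fwd pd; note the
    // role swap: the fwd primitive's dst is this primitive's diff_src, and
    // its src is this primitive's diff_dst.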
    if (weights_md_.format_kind == format_kind::any)
        CHECK(weights_axes_permutation(
                &weights_md_, fwd_pd_->weights_md(), with_groups()));
    if (diff_src_md_.format_kind == format_kind::any)
        diff_src_md_ = *fwd_pd_->dst_md();
    if (diff_dst_md_.format_kind == format_kind::any)
        diff_dst_md_ = *fwd_pd_->src_md();
    if (bias_md_.format_kind == format_kind::any)
        bias_md_ = *fwd_pd_->weights_md(1);

    auto scratchpad = scratchpad_registry().registrar();
    scratchpad.book(
            memory_tracking::names::key_nested, fwd_pd_->scratchpad_registry());

    return status::success;
}

template <cpu_isa_t isa>
status_t brgemm_convolution_bwd_t<isa>::init(engine_t *engine) {
    return pd()->fwd_pd_->create_primitive(fwd_p_, engine);
}

template <cpu_isa_t isa>
status_t brgemm_convolution_bwd_t<isa>::execute(const exec_ctx_t &ctx) const {
    const auto &args = ctx.args();
    exec_args_t conv_args;
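    // Remap bwd-by-data arguments onto the fwd convolution: diff_dst plays
    // the role of the fwd src, and diff_src receives the fwd dst.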
    conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC);
    conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
    conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
    if (pd()->with_bias()) conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);

    exec_ctx_t fwd_ctx(ctx, std::move(conv_args));

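    // Run the nested fwd primitive using the scratchpad space booked for it
    // at pd init time under memory_tracking::names::key_nested.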
    nested_scratchpad_t ns(ctx, memory_tracking::names::key_nested, fwd_p_);
    fwd_ctx.set_scratchpad_grantor(ns.grantor());
    return fwd_p_->execute(fwd_ctx);
}

template struct brgemm_convolution_bwd_t<avx2_vnni_2>;
template struct brgemm_convolution_bwd_t<avx512_core>;
template struct brgemm_convolution_bwd_t<avx512_core_bf16>;
template struct brgemm_convolution_bwd_t<avx512_core_fp16>;
template struct brgemm_convolution_bwd_t<avx512_core_amx>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s