1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/compiler_workarounds.hpp" |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/nstl.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "common/utils.hpp" |
23 | |
24 | #include "cpu/x64/jit_brgemm_conv_bwd.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | namespace { |
32 | status_t weights_axes_permutation( |
33 | memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) { |
34 | int perm[DNNL_MAX_NDIMS] {}; // bwd conv to fwd conv weight permutation |
35 | for (int d = 0; d < DNNL_MAX_NDIMS; ++d) |
36 | perm[d] = d; |
37 | nstl::swap(perm[0 + with_groups], perm[1 + with_groups]); |
38 | |
39 | return memory_desc_permute_axes(*o_md, *i_md, perm); |
40 | } |
41 | |
42 | status_t fwd_conv_desc_create( |
43 | convolution_desc_t *fwd_conv_d, const convolution_desc_t *bwd_conv_d) { |
44 | // create a new weights descriptor with OC and IC transposed; |
45 | // spatial inversion is handled by inverting indices on-the-fly |
46 | memory_desc_t fwd_weights_md; |
47 | const memory_desc_t &bwd_weights_md = bwd_conv_d->weights_desc; |
48 | const bool with_groups |
49 | = bwd_weights_md.ndims == bwd_conv_d->diff_src_desc.ndims + 1; |
50 | CHECK(weights_axes_permutation( |
51 | &fwd_weights_md, &bwd_weights_md, with_groups)); |
52 | |
53 | // create a fwd convolution descriptor with padding adjusted |
54 | // to the perspective of backward propagation, namely: |
55 | // - left padding replaced by left overflow |
56 | // - right padding replaced by right overflow |
57 | const int ndims_spatial = bwd_conv_d->diff_src_desc.ndims - 2; |
58 | dims_t overflow_l; |
59 | dims_t overflow_r; |
60 | dim_t ks = 1; |
61 | for (int i = 0; i < ndims_spatial; i++) { |
62 | // only unit strides are allowed for bwd-to-fwd conversion |
63 | if (bwd_conv_d->strides[i] != 1) return status::unimplemented; |
64 | const dim_t K |
65 | = bwd_weights_md.dims[bwd_weights_md.ndims - ndims_spatial + i]; |
66 | ks *= K; |
67 | const dim_t D = bwd_conv_d->dilates[i]; |
68 | const dim_t PL = bwd_conv_d->padding[0][i]; // left padding |
69 | const dim_t PR = bwd_conv_d->padding[1][i]; // right padding |
70 | constexpr dim_t S = 1; |
71 | // the following relations hold for unit stride only |
72 | overflow_l[i] = ((K - 1) * (D + 1) - PL) / S; |
73 | overflow_r[i] = ((K - 1) * (D + 1) - PR) / S; |
74 | } |
75 | |
76 | CHECK(conv_desc_init(fwd_conv_d, prop_kind::forward_training, |
77 | alg_kind::convolution_direct, &bwd_conv_d->diff_dst_desc, |
78 | &fwd_weights_md, &bwd_conv_d->bias_desc, &bwd_conv_d->diff_src_desc, |
79 | bwd_conv_d->strides, bwd_conv_d->dilates, overflow_l, overflow_r)); |
80 | |
81 | // HACK: Set diff_src_desc and diff_dst_desc as a signal to the primitive |
82 | // descriptor cache that we are using the bwd-via-fwd version of |
83 | // fwd conv and thus need a separate cache entry. Only needed for |
84 | // non-1x1 convs due to spatial inversion of weights. This assumes |
85 | // that external users only use the API to create conv descs, and |
86 | // relies on common/convolution.cpp only setting the expected mem descs. |
87 | // TODO: Pass this information via attributes or integrate the bwd-via-fwd |
88 | // method directly into fwd conv implementations. |
89 | const bool with_spatial_inversion = ks > 1; |
90 | if (with_spatial_inversion) { |
91 | fwd_conv_d->diff_src_desc = fwd_conv_d->src_desc; |
92 | fwd_conv_d->diff_dst_desc = fwd_conv_d->dst_desc; |
93 | } |
94 | return status::success; |
95 | } |
96 | } // namespace |
97 | |
template <cpu_isa_t isa>
status_t brgemm_convolution_bwd_t<isa>::pd_t::init(engine_t *engine) {
    using namespace data_type;
    using namespace utils;

    // This impl handles only plain bwd-data direct convolution with default
    // attributes and no zero-sized tensors; everything else is rejected so
    // another implementation can pick it up.
    const bool ok = is_bwd_d()
            && set_default_alg_kind(alg_kind::convolution_direct)
            && attr()->has_default_values() && !has_zero_dim_memory();
    if (!ok) return status::unimplemented;

    // Re-express the bwd-data conv as an equivalent fwd conv desc
    // (fails with unimplemented e.g. for non-unit strides).
    convolution_desc_t fwd_conv_d = convolution_desc_t();
    CHECK(fwd_conv_desc_create(&fwd_conv_d, desc()));

    primitive_desc_t *pd;
    do {
        // try creating fwd 1x1 conv prim desc
        using fwd_1x1_conv_pd_t =
                typename brgemm_1x1_convolution_fwd_t<isa>::pd_t;
        status_t s = primitive_desc_t::create<fwd_1x1_conv_pd_t>(&pd,
                reinterpret_cast<const op_desc_t *>(&fwd_conv_d), attr(),
                engine, nullptr);
        if (s == status::success) break;
        // try creating fwd conv prim desc; this generic path inverts the
        // weights' spatial indices on-the-fly (needed for non-1x1 kernels)
        constexpr bool use_inversion = true; // invert weights' spatial indices
        using fwd_conv_pd_t =
                typename brgemm_convolution_fwd_t<isa, use_inversion>::pd_t;
        CHECK(primitive_desc_t::create<fwd_conv_pd_t>(&pd,
                reinterpret_cast<const op_desc_t *>(&fwd_conv_d), attr(),
                engine, nullptr));
    } while (false);
    fwd_pd_.reset(pd); // take ownership of the nested fwd conv pd

    // Propagate the memory formats chosen by the nested fwd conv back into
    // this pd, undoing the role swaps: weights get OC/IC permuted back,
    // fwd dst maps to diff_src, and fwd src maps to diff_dst.
    if (weights_md_.format_kind == format_kind::any)
        CHECK(weights_axes_permutation(
                &weights_md_, fwd_pd_->weights_md(), with_groups()));
    if (diff_src_md_.format_kind == format_kind::any)
        diff_src_md_ = *fwd_pd_->dst_md();
    if (diff_dst_md_.format_kind == format_kind::any)
        diff_dst_md_ = *fwd_pd_->src_md();
    if (bias_md_.format_kind == format_kind::any)
        bias_md_ = *fwd_pd_->weights_md(1);

    // Reserve the nested fwd primitive's scratchpad inside this pd's registry.
    auto scratchpad = scratchpad_registry().registrar();
    scratchpad.book(
            memory_tracking::names::key_nested, fwd_pd_->scratchpad_registry());

    return status::success;
}
146 | |
147 | template <cpu_isa_t isa> |
148 | status_t brgemm_convolution_bwd_t<isa>::init(engine_t *engine) { |
149 | return pd()->fwd_pd_->create_primitive(fwd_p_, engine); |
150 | } |
151 | |
152 | template <cpu_isa_t isa> |
153 | status_t brgemm_convolution_bwd_t<isa>::execute(const exec_ctx_t &ctx) const { |
154 | const auto &args = ctx.args(); |
155 | exec_args_t conv_args; |
156 | conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC); |
157 | conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST); |
158 | conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS); |
159 | if (pd()->with_bias()) conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS); |
160 | |
161 | exec_ctx_t fwd_ctx(ctx, std::move(conv_args)); |
162 | |
163 | nested_scratchpad_t ns(ctx, memory_tracking::names::key_nested, fwd_p_); |
164 | fwd_ctx.set_scratchpad_grantor(ns.grantor()); |
165 | return fwd_p_->execute(fwd_ctx); |
166 | } |
167 | |
// Explicit instantiations for every ISA this bwd-via-fwd impl is built for.
template struct brgemm_convolution_bwd_t<avx2_vnni_2>;
template struct brgemm_convolution_bwd_t<avx512_core>;
template struct brgemm_convolution_bwd_t<avx512_core_bf16>;
template struct brgemm_convolution_bwd_t<avx512_core_fp16>;
template struct brgemm_convolution_bwd_t<avx512_core_amx>;
173 | |
174 | } // namespace x64 |
175 | } // namespace cpu |
176 | } // namespace impl |
177 | } // namespace dnnl |
178 | |
179 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
180 | |