1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/type_helpers.hpp"
20
21#include "cpu/binary_injector_utils.hpp"
22#include "cpu/cpu_primitive.hpp"
23#include "cpu/gemm_inner_product.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28
29using namespace dnnl::impl::status;
30using namespace dnnl::impl::prop_kind;
31using namespace dnnl::impl::data_type;
32using namespace dnnl::impl::format_tag;
33using namespace dnnl::impl::primitive_kind;
34
35template <impl::data_type_t data_type>
36status_t gemm_inner_product_fwd_t<data_type>::execute_forward(
37 const exec_ctx_t &ctx) const {
38 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
39 auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
40 auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
41 auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
42 const auto post_ops_binary_rhs_arg_vec
43 = binary_injector_utils::prepare_binary_args(
44 this->pd()->attr()->post_ops_, ctx);
45
46 const dim_t MB = pd()->MB();
47 const dim_t OC = pd()->OC();
48 const dim_t IC = pd()->IC_total_padded();
49
50 const auto &wmd = *pd()->weights_md();
51 const auto &smd = *pd()->src_md();
52 // check if OC is NOT the leading dimension
53 bool wei_tr = wmd.format_desc.blocking.strides[0] != 1;
54 // check if MB is the leading dimension
55 bool src_tr = smd.format_desc.blocking.strides[0] == 1 && IC > 1;
56
57 float alpha = 1.;
58 status_t st = extended_sgemm(wei_tr ? "T" : "N", src_tr ? "T" : "N", &OC,
59 &MB, &IC, &alpha, weights, wei_tr ? &IC : &OC, src,
60 src_tr ? &MB : &IC, &beta_, dst, &OC,
61 postops_in_ip_ ? nullptr : bias);
62
63 if (st != status::success) return st;
64
65 if (postops_in_ip_) {
66 const bool force_sequential = pp_kernel_->sequential_kernel();
67 parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) {
68 size_t start, end;
69 balance211((size_t)(OC * MB), nthr, ithr, start, end);
70 const size_t dim1_off = start % OC;
71 (*pp_kernel_)(dst, dst, (char *)bias, nullptr, 1.0f, start, start,
72 dim1_off, end, 0,
73 pd()->OC() * pd()->OD() * pd()->OH() * pd()->OW(), nullptr,
74 post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx,
75 *pd()->dst_md());
76 });
77 }
78
79 return status::success;
80}
81
82template <impl::data_type_t data_type>
83status_t gemm_inner_product_bwd_data_t<data_type>::execute_backward_data(
84 const exec_ctx_t &ctx) const {
85 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
86 auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
87 auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
88
89 const dim_t MB = pd()->MB();
90 const dim_t OC = pd()->OC();
91 const dim_t IC = pd()->IC_total_padded();
92
93 const auto &wmd = *pd()->weights_md();
94 const auto &smd = *pd()->diff_src_md();
95 bool wei_tr = wmd.format_desc.blocking.strides[0] == 1;
96 // check if MB is the leading dimension
97 bool dsrc_tr = smd.format_desc.blocking.strides[0] == 1 && IC > 1;
98
99 float alpha = 1.0, beta = 0.0;
100 status_t st = status::success;
101 if (dsrc_tr)
102 st = extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &IC, &MB, &alpha,
103 diff_dst, &OC, weights, wei_tr ? &OC : &IC, &beta, diff_src,
104 &MB);
105 else
106 st = extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha,
107 weights, wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src,
108 &IC);
109
110 return st;
111}
112
113template <impl::data_type_t data_type>
114status_t gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights(
115 const exec_ctx_t &ctx) const {
116 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
117 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
118 auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
119 auto diff_bias = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);
120
121 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
122 const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));
123
124 diff_dst += diff_dst_d.offset0();
125
126 const dim_t MB = pd()->MB();
127 const dim_t OC = pd()->OC();
128 const dim_t IC = pd()->IC_total_padded();
129
130 const auto &wmd = *pd()->diff_weights_md();
131 const auto &smd = *pd()->src_md();
132 bool wei_tr = wmd.format_desc.blocking.strides[0] == 1;
133 // check if MB is the leading dimension
134 bool src_tr = smd.format_desc.blocking.strides[0] == 1 && IC > 1;
135
136 float alpha = 1.0, beta = 0.0;
137 status_t st = status::success;
138 if (wei_tr)
139 st = extended_sgemm("N", src_tr ? "N" : "T", &OC, &IC, &MB, &alpha,
140 diff_dst, &OC, src, src_tr ? &MB : &IC, &beta, diff_weights,
141 &OC);
142 else
143 st = extended_sgemm("N", src_tr ? "N" : "T", &IC, &OC, &MB, &alpha, src,
144 src_tr ? &MB : &IC, diff_dst, &OC, &beta, diff_weights, &IC);
145
146 if (st != status::success) return st;
147
148 if (diff_bias) {
149 diff_bias += diff_bias_d.offset0();
150 constexpr dim_t blksize = 8;
151 const dim_t OC_blocks = utils::div_up(OC, blksize);
152 parallel(0, [&](const int ithr, const int nthr) {
153 dim_t oc_s {0}, oc_e {0};
154 balance211(OC_blocks, nthr, ithr, oc_s, oc_e);
155 oc_s = std::min(oc_s * blksize, OC);
156 oc_e = std::min(oc_e * blksize, OC);
157
158 PRAGMA_OMP_SIMD()
159 for (dim_t oc = oc_s; oc < oc_e; ++oc) {
160 diff_bias[oc] = diff_dst[oc];
161 }
162
163 for (dim_t mb = 1; mb < MB; ++mb) {
164 PRAGMA_OMP_SIMD()
165 for (dim_t oc = oc_s; oc < oc_e; ++oc) {
166 diff_bias[oc] += diff_dst[mb * OC + oc];
167 }
168 }
169 });
170 }
171
172 return status::success;
173}
174
// Explicit instantiations: only the f32 flavors are built, since the
// implementation is backed by the single-precision (sgemm) GEMM entry point.
template struct gemm_inner_product_fwd_t<data_type::f32>;
template struct gemm_inner_product_bwd_data_t<data_type::f32>;
template struct gemm_inner_product_bwd_weights_t<data_type::f32>;
178
179} // namespace cpu
180} // namespace impl
181} // namespace dnnl
182
183// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
184