1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | |
21 | #include "cpu/binary_injector_utils.hpp" |
22 | #include "cpu/cpu_primitive.hpp" |
23 | #include "cpu/gemm_inner_product.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | |
29 | using namespace dnnl::impl::status; |
30 | using namespace dnnl::impl::prop_kind; |
31 | using namespace dnnl::impl::data_type; |
32 | using namespace dnnl::impl::format_tag; |
33 | using namespace dnnl::impl::primitive_kind; |
34 | |
template <impl::data_type_t data_type>
status_t gemm_inner_product_fwd_t<data_type>::execute_forward(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    // Gather the rhs tensors of any binary post-ops up front so the
    // post-processing kernel below can consume them by index.
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector_utils::prepare_binary_args(
                    this->pd()->attr()->post_ops_, ctx);

    const dim_t MB = pd()->MB();
    const dim_t OC = pd()->OC();
    // IC here is the reduction size: all input channels (incl. padding).
    const dim_t IC = pd()->IC_total_padded();

    const auto &wmd = *pd()->weights_md();
    const auto &smd = *pd()->src_md();
    // check if OC is NOT the leading dimension
    bool wei_tr = wmd.format_desc.blocking.strides[0] != 1;
    // check if MB is the leading dimension
    bool src_tr = smd.format_desc.blocking.strides[0] == 1 && IC > 1;

    float alpha = 1.;
    // Single BLAS-style (column-major) GEMM: dst is viewed as an OC x MB
    // matrix with leading dimension OC (m = OC, n = MB, k = IC). The
    // transpose flags and leading dimensions adapt to how weights/src are
    // laid out in memory. beta_ / postops_in_ip_ are member state; when
    // post-ops are NOT handled inside the GEMM, bias is fused here,
    // otherwise it is deferred to pp_kernel_ below.
    status_t st = extended_sgemm(wei_tr ? "T" : "N" , src_tr ? "T" : "N" , &OC,
            &MB, &IC, &alpha, weights, wei_tr ? &IC : &OC, src,
            src_tr ? &MB : &IC, &beta_, dst, &OC,
            postops_in_ip_ ? nullptr : bias);

    if (st != status::success) return st;

    // Apply bias and attribute post-ops with the post-processing kernel,
    // parallelized over the flat OC * MB destination buffer.
    if (postops_in_ip_) {
        const bool force_sequential = pp_kernel_->sequential_kernel();
        parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) {
            size_t start, end;
            balance211((size_t)(OC * MB), nthr, ithr, start, end);
            // Position along the OC axis where this thread's chunk starts
            // (needed for correct bias indexing).
            const size_t dim1_off = start % OC;
            (*pp_kernel_)(dst, dst, (char *)bias, nullptr, 1.0f, start, start,
                    dim1_off, end, 0,
                    pd()->OC() * pd()->OD() * pd()->OH() * pd()->OW(), nullptr,
                    post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx,
                    *pd()->dst_md());
        });
    }

    return status::success;
}
81 | |
82 | template <impl::data_type_t data_type> |
83 | status_t gemm_inner_product_bwd_data_t<data_type>::execute_backward_data( |
84 | const exec_ctx_t &ctx) const { |
85 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
86 | auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); |
87 | auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); |
88 | |
89 | const dim_t MB = pd()->MB(); |
90 | const dim_t OC = pd()->OC(); |
91 | const dim_t IC = pd()->IC_total_padded(); |
92 | |
93 | const auto &wmd = *pd()->weights_md(); |
94 | const auto &smd = *pd()->diff_src_md(); |
95 | bool wei_tr = wmd.format_desc.blocking.strides[0] == 1; |
96 | // check if MB is the leading dimension |
97 | bool dsrc_tr = smd.format_desc.blocking.strides[0] == 1 && IC > 1; |
98 | |
99 | float alpha = 1.0, beta = 0.0; |
100 | status_t st = status::success; |
101 | if (dsrc_tr) |
102 | st = extended_sgemm(wei_tr ? "T" : "N" , "N" , &OC, &IC, &MB, &alpha, |
103 | diff_dst, &OC, weights, wei_tr ? &OC : &IC, &beta, diff_src, |
104 | &MB); |
105 | else |
106 | st = extended_sgemm(wei_tr ? "T" : "N" , "N" , &IC, &MB, &OC, &alpha, |
107 | weights, wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src, |
108 | &IC); |
109 | |
110 | return st; |
111 | } |
112 | |
113 | template <impl::data_type_t data_type> |
114 | status_t gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights( |
115 | const exec_ctx_t &ctx) const { |
116 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
117 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
118 | auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS); |
119 | auto diff_bias = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS); |
120 | |
121 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
122 | const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); |
123 | |
124 | diff_dst += diff_dst_d.offset0(); |
125 | |
126 | const dim_t MB = pd()->MB(); |
127 | const dim_t OC = pd()->OC(); |
128 | const dim_t IC = pd()->IC_total_padded(); |
129 | |
130 | const auto &wmd = *pd()->diff_weights_md(); |
131 | const auto &smd = *pd()->src_md(); |
132 | bool wei_tr = wmd.format_desc.blocking.strides[0] == 1; |
133 | // check if MB is the leading dimension |
134 | bool src_tr = smd.format_desc.blocking.strides[0] == 1 && IC > 1; |
135 | |
136 | float alpha = 1.0, beta = 0.0; |
137 | status_t st = status::success; |
138 | if (wei_tr) |
139 | st = extended_sgemm("N" , src_tr ? "N" : "T" , &OC, &IC, &MB, &alpha, |
140 | diff_dst, &OC, src, src_tr ? &MB : &IC, &beta, diff_weights, |
141 | &OC); |
142 | else |
143 | st = extended_sgemm("N" , src_tr ? "N" : "T" , &IC, &OC, &MB, &alpha, src, |
144 | src_tr ? &MB : &IC, diff_dst, &OC, &beta, diff_weights, &IC); |
145 | |
146 | if (st != status::success) return st; |
147 | |
148 | if (diff_bias) { |
149 | diff_bias += diff_bias_d.offset0(); |
150 | constexpr dim_t blksize = 8; |
151 | const dim_t OC_blocks = utils::div_up(OC, blksize); |
152 | parallel(0, [&](const int ithr, const int nthr) { |
153 | dim_t oc_s {0}, oc_e {0}; |
154 | balance211(OC_blocks, nthr, ithr, oc_s, oc_e); |
155 | oc_s = std::min(oc_s * blksize, OC); |
156 | oc_e = std::min(oc_e * blksize, OC); |
157 | |
158 | PRAGMA_OMP_SIMD() |
159 | for (dim_t oc = oc_s; oc < oc_e; ++oc) { |
160 | diff_bias[oc] = diff_dst[oc]; |
161 | } |
162 | |
163 | for (dim_t mb = 1; mb < MB; ++mb) { |
164 | PRAGMA_OMP_SIMD() |
165 | for (dim_t oc = oc_s; oc < oc_e; ++oc) { |
166 | diff_bias[oc] += diff_dst[mb * OC + oc]; |
167 | } |
168 | } |
169 | }); |
170 | } |
171 | |
172 | return status::success; |
173 | } |
174 | |
// Explicit instantiations: this GEMM-based inner product is built for f32 only.
template struct gemm_inner_product_fwd_t<data_type::f32>;
template struct gemm_inner_product_bwd_data_t<data_type::f32>;
template struct gemm_inner_product_bwd_weights_t<data_type::f32>;
178 | |
179 | } // namespace cpu |
180 | } // namespace impl |
181 | } // namespace dnnl |
182 | |
183 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
184 | |