/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"

#include "cpu/ref_io_helper.hpp"

#include "cpu/ref_inner_product.hpp"
#include "cpu/ref_inner_product_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
status_t ref_inner_product_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    status_t status = status::success;
    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status);
    CHECK(status);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto ndims = pd()->ndims();
    const auto MB = pd()->MB();
    const auto OC = pd()->OC();
    const auto IC = pd()->IC();

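    // Accumulate src(mb, ic, kd, kh, kw) * weights(oc, ic, kd, kh, kw) over
    // the reduction dimensions for a single (mb, oc) output element. Values
    // are loaded through io::load_float_value, which converts any supported
    // data type to float for the accumulation.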
    auto ker = [=](dim_t mb, dim_t oc) {
        float d = 0;
        const dim_t KD = pd()->KD();
        const dim_t KH = pd()->KH();
        const dim_t KW = pd()->KW();
        for_(dim_t ic = 0; ic < IC; ++ic)
        for_(dim_t kd = 0; kd < KD; ++kd)
        for_(dim_t kh = 0; kh < KH; ++kh)
        for (dim_t kw = 0; kw < KW; ++kw) {
            const auto src_off = ref_ip_utils::get_data_off(
                    src_d, ndims, mb, ic, kd, kh, kw);
            const auto wei_off = ref_ip_utils::get_weights_off(
                    weights_d, ndims, oc, ic, kd, kh, kw);
            const float s
                    = io::load_float_value(src_d.data_type(), src, src_off);
            const float w = io::load_float_value(
                    weights_d.data_type(), weights, wei_off);
            d += s * w;
        }
        return d;
    };

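    // For each (mb, oc) output element: add the bias (if present), apply
    // post-ops, and store the result back in the destination data type.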
    parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) {
        float acc = ker(mb, oc);

        float d = acc;
        if (bias) {
            const auto bias_off = bias_d.off(oc);
            const float b
                    = io::load_float_value(bias_d.data_type(), bias, bias_off);
            d += b;
        }

        dim_t dst_off = dst_d.off(mb, oc);
        dim_t dst_l_off = (mb * OC + oc);

        ref_post_ops_t::args_t args;
        args.dst_val = io::load_float_value(dst_d.data_type(), dst, dst_off);
        args.ctx = &ctx;
        args.l_offset = dst_l_off;
        args.dst_md = pd()->dst_md();
        ref_post_ops->execute(d, args);

        io::store_float_value(dst_d.data_type(), d, dst, dst_off);
    });

    return status::success;
}

status_t ref_inner_product_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;
    auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status);
    CHECK(status);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());

    const auto ndims = pd()->ndims();
    const auto MB = pd()->MB();
    const auto OC = pd()->OC();
    const auto IC = pd()->IC();

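    // diff_src(mb, ic, kd, kh, kw) = sum over oc of
    // diff_dst(mb, oc) * weights(oc, ic, kd, kh, kw). The destination of an
    // inner product is always 2D, hence diff_dst offsets are computed with
    // ndims = 2.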
    parallel_nd(MB, IC, [&](dim_t mb, dim_t ic) {
        const dim_t KD = pd()->KD();
        const dim_t KH = pd()->KH();
        const dim_t KW = pd()->KW();
        for_(dim_t kd = 0; kd < KD; ++kd)
        for_(dim_t kh = 0; kh < KH; ++kh)
        for (dim_t kw = 0; kw < KW; ++kw) {
            float ds = 0;
            for (dim_t oc = 0; oc < OC; ++oc) {
                const auto diff_dst_off = ref_ip_utils::get_data_off(
                        diff_dst_d, 2, mb, oc, 0, 0, 0);
                const auto wei_off = ref_ip_utils::get_weights_off(
                        weights_d, ndims, oc, ic, kd, kh, kw);
                const float dd = io::load_float_value(
                        diff_dst_d.data_type(), diff_dst, diff_dst_off);
                const float w = io::load_float_value(
                        weights_d.data_type(), weights, wei_off);
                ds += dd * w;
            }
            const auto diff_src_off = ref_ip_utils::get_data_off(
                    diff_src_d, ndims, mb, ic, kd, kh, kw);
            io::store_float_value(
                    diff_src_d.data_type(), ds, diff_src, diff_src_off);
        }
    });

    return status::success;
}

status_t ref_inner_product_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    status_t status = status::success;
    auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto diff_weights
            = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_WEIGHTS, status);
    CHECK(status);
    auto diff_bias = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_BIAS, status);
    CHECK(status);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));

    const auto ndims = src_d.ndims();
    const auto MB = pd()->MB();
    const auto OC = pd()->OC();
    const auto IC = pd()->IC();

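    // diff_weights(oc, ic, kd, kh, kw) = sum over mb of
    // diff_dst(mb, oc) * src(mb, ic, kd, kh, kw).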
    parallel_nd(OC, IC, [&](dim_t oc, dim_t ic) {
        const dim_t KD = pd()->KD();
        const dim_t KH = pd()->KH();
        const dim_t KW = pd()->KW();
        for_(dim_t kd = 0; kd < KD; ++kd)
        for_(dim_t kh = 0; kh < KH; ++kh)
        for (dim_t kw = 0; kw < KW; ++kw) {
            float dw = 0;
            for (dim_t mb = 0; mb < MB; ++mb) {
                const auto diff_dst_off = ref_ip_utils::get_data_off(
                        diff_dst_d, 2, mb, oc, 0, 0, 0);
                const auto src_off = ref_ip_utils::get_data_off(
                        src_d, ndims, mb, ic, kd, kh, kw);
                const float dd = io::load_float_value(
                        diff_dst_d.data_type(), diff_dst, diff_dst_off);
                const float s = io::load_float_value(
                        src_d.data_type(), src, src_off);
                dw += dd * s;
            }
            const auto diff_wei_off = ref_ip_utils::get_weights_off(
                    diff_weights_d, ndims, oc, ic, kd, kh, kw);
            io::store_float_value(diff_weights_d.data_type(), dw, diff_weights,
                    diff_wei_off);
        }
    });

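    // diff_bias(oc) = sum over mb of diff_dst(mb, oc); computed only when the
    // primitive has a bias (diff_bias is non-null).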
    if (diff_bias) {
        parallel_nd(OC, [&](dim_t oc) {
            float db = 0;
            for (dim_t mb = 0; mb < MB; ++mb) {
                const auto diff_dst_off = ref_ip_utils::get_data_off(
                        diff_dst_d, 2, mb, oc, 0, 0, 0);
                const float dd = io::load_float_value(
                        diff_dst_d.data_type(), diff_dst, diff_dst_off);
                db += dd;
            }

            const auto diff_bia_off = diff_bias_d.off(oc);
            io::store_float_value(
                    diff_bias_d.data_type(), db, diff_bias, diff_bia_off);
        });
    }

    return status::success;
}

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s