1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils/parallel.hpp"
18
19#include "ip/ip.hpp"
20
21namespace ip {
22
23void compute_ref_fwd_ip(const prb_t *prb, const args_t &args) {
24 const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
25 const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
26 const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
27 const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
28
29 int64_t M = prb->mb;
30 int64_t N = prb->oc;
31 int64_t K = prb->ic * prb->id * prb->ih * prb->iw;
32
33 dnn_mem_t dst_tmp(dst_m.md_, dnnl_f32, tag::abx, dst_m.engine());
34
35 gemm("C", "N", "T", M, N, K, 1.f, (float *)src_m, K, (float *)wei_m, K, 0.f,
36 (float *)dst_tmp, N);
37
38 auto v_po_masks = prb->attr.post_ops.get_po_masks();
39 benchdnn_parallel_nd(prb->mb, prb->oc, [&](int64_t mb, int64_t oc) {
40 size_t dst_off = dst_off_f(prb, mb, oc);
41 float &dst = ((float *)dst_m)[dst_off];
42
43 float d = ((float *)dst_tmp)[dst_off];
44
45 maybe_scale(prb->attr, d, prb->src_scales, 0, DNNL_ARG_SRC);
46 maybe_scale(prb->attr, d, prb->wei_scales, oc, DNNL_ARG_WEIGHTS);
47
48 if (prb->dir & FLAG_BIA) {
49 size_t bia_off = bia_off_f(prb, oc);
50 d += ((float *)bia_m)[bia_off];
51 }
52
53 const auto v_po_vals
54 = prepare_po_vals(dst_m, args, v_po_masks, dst_off);
55
56 maybe_post_ops(prb->attr, d, dst, v_po_vals);
57
58 maybe_scale(prb->attr, d, prb->dst_scales, oc, DNNL_ARG_DST, true);
59 dst = d;
60 });
61}
62
63void compute_ref_bwd_d_ip(const prb_t *prb, const args_t &args) {
64 const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC);
65 const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
66 const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
67
68 int64_t M = prb->mb;
69 int64_t N = prb->ic * prb->id * prb->ih * prb->iw;
70 int64_t K = prb->oc;
71
72 gemm("C", "N", "N", M, N, K, 1.f, (float *)diff_dst_m, K, (float *)wei_m, N,
73 0.f, (float *)diff_src_m, N);
74}
75
76void compute_ref_bwd_w_ip(const prb_t *prb, const args_t &args) {
77 const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
78 const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS);
79 const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
80 const dnn_mem_t &diff_bia_m = args.find(DNNL_ARG_DIFF_BIAS);
81
82 int64_t M = prb->oc;
83 int64_t N = prb->ic * prb->id * prb->ih * prb->iw;
84 int64_t K = prb->mb;
85
86 gemm("C", "T", "N", M, N, K, 1.f, (float *)diff_dst_m, M, (float *)src_m, N,
87 0.f, (float *)diff_wei_m, N);
88
89 if (!(prb->dir & FLAG_BIA)) return;
90
91 benchdnn_parallel_nd(prb->oc, [&](int64_t oc) {
92 size_t bia_off = bia_off_f(prb, oc);
93 float &db = ((float *)diff_bia_m)[bia_off];
94 db = 0;
95 for (int64_t mb = 0; mb < prb->mb; ++mb) {
96 size_t dst_off = dst_off_f(prb, mb, oc);
97 db += ((float *)diff_dst_m)[dst_off];
98 }
99 });
100}
101
102void compute_ref_fwd(
103 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
104 if (prim_ref) {
105 SAFE_V(execute_and_wait(prim_ref, args));
106 return;
107 }
108
109 compute_ref_fwd_ip(prb, args);
110}
111
112void compute_ref_bwd_d(
113 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
114 if (prim_ref) {
115 SAFE_V(execute_and_wait(prim_ref, args));
116 return;
117 }
118
119 compute_ref_bwd_d_ip(prb, args);
120}
121
122void compute_ref_bwd_w(
123 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
124 if (prim_ref) {
125 SAFE_V(execute_and_wait(prim_ref, args));
126 return;
127 }
128
129 compute_ref_bwd_w_ip(prb, args);
130}
131
132void compute_ref(
133 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
134 if (prb->dir & FLAG_FWD)
135 compute_ref_fwd(prb, args, prim_ref);
136 else if (prb->dir == BWD_D)
137 compute_ref_bwd_d(prb, args, prim_ref);
138 else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI)
139 compute_ref_bwd_w(prb, args, prim_ref);
140}
141
142} // namespace ip
143