1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "ip/ip.hpp" |
20 | |
21 | namespace ip { |
22 | |
23 | void compute_ref_fwd_ip(const prb_t *prb, const args_t &args) { |
24 | const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC); |
25 | const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS); |
26 | const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS); |
27 | const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST); |
28 | |
29 | int64_t M = prb->mb; |
30 | int64_t N = prb->oc; |
31 | int64_t K = prb->ic * prb->id * prb->ih * prb->iw; |
32 | |
33 | dnn_mem_t dst_tmp(dst_m.md_, dnnl_f32, tag::abx, dst_m.engine()); |
34 | |
35 | gemm("C" , "N" , "T" , M, N, K, 1.f, (float *)src_m, K, (float *)wei_m, K, 0.f, |
36 | (float *)dst_tmp, N); |
37 | |
38 | auto v_po_masks = prb->attr.post_ops.get_po_masks(); |
39 | benchdnn_parallel_nd(prb->mb, prb->oc, [&](int64_t mb, int64_t oc) { |
40 | size_t dst_off = dst_off_f(prb, mb, oc); |
41 | float &dst = ((float *)dst_m)[dst_off]; |
42 | |
43 | float d = ((float *)dst_tmp)[dst_off]; |
44 | |
45 | maybe_scale(prb->attr, d, prb->src_scales, 0, DNNL_ARG_SRC); |
46 | maybe_scale(prb->attr, d, prb->wei_scales, oc, DNNL_ARG_WEIGHTS); |
47 | |
48 | if (prb->dir & FLAG_BIA) { |
49 | size_t bia_off = bia_off_f(prb, oc); |
50 | d += ((float *)bia_m)[bia_off]; |
51 | } |
52 | |
53 | const auto v_po_vals |
54 | = prepare_po_vals(dst_m, args, v_po_masks, dst_off); |
55 | |
56 | maybe_post_ops(prb->attr, d, dst, v_po_vals); |
57 | |
58 | maybe_scale(prb->attr, d, prb->dst_scales, oc, DNNL_ARG_DST, true); |
59 | dst = d; |
60 | }); |
61 | } |
62 | |
63 | void compute_ref_bwd_d_ip(const prb_t *prb, const args_t &args) { |
64 | const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC); |
65 | const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS); |
66 | const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST); |
67 | |
68 | int64_t M = prb->mb; |
69 | int64_t N = prb->ic * prb->id * prb->ih * prb->iw; |
70 | int64_t K = prb->oc; |
71 | |
72 | gemm("C" , "N" , "N" , M, N, K, 1.f, (float *)diff_dst_m, K, (float *)wei_m, N, |
73 | 0.f, (float *)diff_src_m, N); |
74 | } |
75 | |
76 | void compute_ref_bwd_w_ip(const prb_t *prb, const args_t &args) { |
77 | const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC); |
78 | const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS); |
79 | const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST); |
80 | const dnn_mem_t &diff_bia_m = args.find(DNNL_ARG_DIFF_BIAS); |
81 | |
82 | int64_t M = prb->oc; |
83 | int64_t N = prb->ic * prb->id * prb->ih * prb->iw; |
84 | int64_t K = prb->mb; |
85 | |
86 | gemm("C" , "T" , "N" , M, N, K, 1.f, (float *)diff_dst_m, M, (float *)src_m, N, |
87 | 0.f, (float *)diff_wei_m, N); |
88 | |
89 | if (!(prb->dir & FLAG_BIA)) return; |
90 | |
91 | benchdnn_parallel_nd(prb->oc, [&](int64_t oc) { |
92 | size_t bia_off = bia_off_f(prb, oc); |
93 | float &db = ((float *)diff_bia_m)[bia_off]; |
94 | db = 0; |
95 | for (int64_t mb = 0; mb < prb->mb; ++mb) { |
96 | size_t dst_off = dst_off_f(prb, mb, oc); |
97 | db += ((float *)diff_dst_m)[dst_off]; |
98 | } |
99 | }); |
100 | } |
101 | |
102 | void compute_ref_fwd( |
103 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
104 | if (prim_ref) { |
105 | SAFE_V(execute_and_wait(prim_ref, args)); |
106 | return; |
107 | } |
108 | |
109 | compute_ref_fwd_ip(prb, args); |
110 | } |
111 | |
112 | void compute_ref_bwd_d( |
113 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
114 | if (prim_ref) { |
115 | SAFE_V(execute_and_wait(prim_ref, args)); |
116 | return; |
117 | } |
118 | |
119 | compute_ref_bwd_d_ip(prb, args); |
120 | } |
121 | |
122 | void compute_ref_bwd_w( |
123 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
124 | if (prim_ref) { |
125 | SAFE_V(execute_and_wait(prim_ref, args)); |
126 | return; |
127 | } |
128 | |
129 | compute_ref_bwd_w_ip(prb, args); |
130 | } |
131 | |
132 | void compute_ref( |
133 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
134 | if (prb->dir & FLAG_FWD) |
135 | compute_ref_fwd(prb, args, prim_ref); |
136 | else if (prb->dir == BWD_D) |
137 | compute_ref_bwd_d(prb, args, prim_ref); |
138 | else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) |
139 | compute_ref_bwd_w(prb, args, prim_ref); |
140 | } |
141 | |
142 | } // namespace ip |
143 | |