1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "binary/binary.hpp" |
20 | |
21 | namespace binary { |
22 | |
23 | void compute_ref( |
24 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
25 | const dnn_mem_t &src0 = args.find(DNNL_ARG_SRC_0); |
26 | const dnn_mem_t &src1 = args.find(DNNL_ARG_SRC_1); |
27 | const dnn_mem_t &dst = args.find(DNNL_ARG_DST); |
28 | |
29 | float *dst_ptr = (float *)dst; |
30 | const float *A = (const float *)src0; |
31 | const float *B = (const float *)src1; |
32 | |
33 | float scales[2] = {prb->attr.scales.get(DNNL_ARG_SRC_0).scale, |
34 | prb->attr.scales.get(DNNL_ARG_SRC_1).scale}; |
35 | |
36 | const auto nelems = dst.nelems(); |
37 | const auto broadcast_mask_A = prb->get_broadcast_mask(0); |
38 | const auto broadcast_mask_B = prb->get_broadcast_mask(1); |
39 | auto v_po_masks = prb->attr.post_ops.get_po_masks(); |
40 | |
41 | benchdnn_parallel_nd(nelems, [&](int64_t i) { |
42 | const auto idx_A = dst.get_scale_idx(i, broadcast_mask_A); |
43 | const auto idx_B = dst.get_scale_idx(i, broadcast_mask_B); |
44 | float res = compute_binary( |
45 | prb->alg, scales[0] * A[idx_A], scales[1] * B[idx_B]); |
46 | float &dst_fp = dst_ptr[i]; |
47 | |
48 | const auto v_po_vals = prepare_po_vals(dst, args, v_po_masks, i); |
49 | |
50 | maybe_post_ops( |
51 | prb->attr, res, maybe_saturate(prb->ddt, dst_fp), v_po_vals); |
52 | maybe_saturate(prb->ddt, res); |
53 | dst_fp = res; |
54 | }); |
55 | } |
56 | |
57 | } // namespace binary |
58 | |