1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils/parallel.hpp"
18
19#include "binary/binary.hpp"
20
21namespace binary {
22
23void compute_ref(
24 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
25 const dnn_mem_t &src0 = args.find(DNNL_ARG_SRC_0);
26 const dnn_mem_t &src1 = args.find(DNNL_ARG_SRC_1);
27 const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
28
29 float *dst_ptr = (float *)dst;
30 const float *A = (const float *)src0;
31 const float *B = (const float *)src1;
32
33 float scales[2] = {prb->attr.scales.get(DNNL_ARG_SRC_0).scale,
34 prb->attr.scales.get(DNNL_ARG_SRC_1).scale};
35
36 const auto nelems = dst.nelems();
37 const auto broadcast_mask_A = prb->get_broadcast_mask(0);
38 const auto broadcast_mask_B = prb->get_broadcast_mask(1);
39 auto v_po_masks = prb->attr.post_ops.get_po_masks();
40
41 benchdnn_parallel_nd(nelems, [&](int64_t i) {
42 const auto idx_A = dst.get_scale_idx(i, broadcast_mask_A);
43 const auto idx_B = dst.get_scale_idx(i, broadcast_mask_B);
44 float res = compute_binary(
45 prb->alg, scales[0] * A[idx_A], scales[1] * B[idx_B]);
46 float &dst_fp = dst_ptr[i];
47
48 const auto v_po_vals = prepare_po_vals(dst, args, v_po_masks, i);
49
50 maybe_post_ops(
51 prb->attr, res, maybe_saturate(prb->ddt, dst_fp), v_po_vals);
52 maybe_saturate(prb->ddt, res);
53 dst_fp = res;
54 });
55}
56
57} // namespace binary
58