1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <float.h>
19#include <math.h>
20
21#include "common/c_types_map.hpp"
22#include "common/dnnl_thread.hpp"
23#include "common/math_utils.hpp"
24#include "common/nstl.hpp"
25#include "common/type_helpers.hpp"
26
27#include "cpu/cpu_primitive.hpp"
28#include "cpu/ref_io_helper.hpp"
29#include "cpu/simple_q10n.hpp"
30
31#include "cpu/ref_binary.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36
37status_t ref_binary_t::execute_ref(const exec_ctx_t &ctx) const {
38 const auto src0 = CTX_IN_MEM(const void *, DNNL_ARG_SRC_0);
39 const auto src1 = CTX_IN_MEM(const void *, DNNL_ARG_SRC_1);
40 auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
41
42 const float *scales[2];
43 ASSIGN_ARG_SCALE_VALUE(scales[0], DNNL_ARG_SRC_0);
44 ASSIGN_ARG_SCALE_VALUE(scales[1], DNNL_ARG_SRC_1);
45
46 const memory_desc_wrapper src0_d(pd()->src_md(0));
47 const memory_desc_wrapper src1_d(pd()->src_md(1));
48 const memory_desc_wrapper dst_d(pd()->dst_md());
49
50 const auto src0_dt = src0_d.data_type();
51 const auto src1_dt = src1_d.data_type();
52 const auto dst_dt = dst_d.data_type();
53
54 const auto alg = pd()->desc()->alg_kind;
55
56 const auto nelems = dst_d.nelems();
57 const auto ndims = pd()->ndims();
58 const auto has_postops = pd()->attr()->post_ops_.len() != 0;
59 const auto is_inplace
60 = static_cast<const void *>(src0) == static_cast<void *>(dst);
61 bool has_padding = false;
62 for (int i = 0; i < dst_d.ndims(); i++)
63 if (dst_d.dims()[i] != dst_d.padded_dims()[i]) {
64 has_padding = true;
65 break;
66 }
67
68 if (has_padding && !is_inplace) {
69 if (has_postops || !dst_d.is_dense(true)) {
70 // Use zero-padding implementation as we cannot memset over
71 // populated dst memory or submemories.
72 ctx.zero_pad_output(DNNL_ARG_TO);
73 } else {
74 const auto res = std::div(static_cast<int>(dst_d.size()), PAGE_4K);
75 if (!res.quot)
76 std::memset(dst, 0, res.rem);
77 else
78 parallel_nd(res.quot, [&](dim_t i) {
79 const auto tail = (i + 1 == res.quot) ? res.rem : 0;
80 const auto ptr = reinterpret_cast<unsigned char *>(dst)
81 + i * PAGE_4K;
82 std::memset(ptr, 0, PAGE_4K + tail);
83 });
84 }
85 }
86
87 parallel_nd(nelems, [&](dim_t i) {
88 dims_t dims_src0, dims_src1; // decomposition for physical offsets
89 utils::l_dims_by_l_offset(dims_src0, i, dst_d.dims(), ndims);
90 utils::l_dims_by_l_offset(dims_src1, i, dst_d.dims(), ndims);
91 auto off_C = dst_d.off_v(dims_src0);
92
93 int mask_src0
94 = utils::get_dims_mask(dst_d.dims(), src0_d.dims(), ndims);
95 utils::apply_mask_on_dims(dims_src0, ndims, mask_src0);
96 const auto off_A = src0_d.off_v(dims_src0);
97 int mask_src1
98 = utils::get_dims_mask(dst_d.dims(), src1_d.dims(), ndims);
99 utils::apply_mask_on_dims(dims_src1, ndims, mask_src1);
100 const auto off_B = src1_d.off_v(dims_src1);
101
102 float x_f = io::load_float_value(src0_dt, src0, off_A);
103 float y_f = io::load_float_value(src1_dt, src1, off_B);
104 float dst_f = io::load_float_value(dst_dt, dst, off_C);
105
106 x_f *= scales[0][0];
107 y_f *= scales[1][0];
108
109 float acc = compute_binary_scalar(alg, x_f, y_f);
110
111 if (has_postops) {
112 ref_post_ops_t::args_t args;
113 args.dst_val = dst_f;
114 args.ctx = &ctx;
115 args.l_offset = i;
116 args.dst_md = pd()->dst_md();
117 ref_post_ops->execute(acc, args);
118 }
119
120 io::store_float_value(dst_dt, acc, dst, off_C);
121 });
122
123 return status::success;
124}
125
126} // namespace cpu
127} // namespace impl
128} // namespace dnnl
129
130// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
131