1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
#include <assert.h>
#include <float.h>
#include <math.h>

#include <cstdlib>
#include <cstring>
20 | |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/math_utils.hpp" |
24 | #include "common/nstl.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | |
27 | #include "cpu/cpu_primitive.hpp" |
28 | #include "cpu/ref_io_helper.hpp" |
29 | #include "cpu/simple_q10n.hpp" |
30 | |
31 | #include "cpu/ref_binary.hpp" |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | |
37 | status_t ref_binary_t::execute_ref(const exec_ctx_t &ctx) const { |
38 | const auto src0 = CTX_IN_MEM(const void *, DNNL_ARG_SRC_0); |
39 | const auto src1 = CTX_IN_MEM(const void *, DNNL_ARG_SRC_1); |
40 | auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST); |
41 | |
42 | const float *scales[2]; |
43 | ASSIGN_ARG_SCALE_VALUE(scales[0], DNNL_ARG_SRC_0); |
44 | ASSIGN_ARG_SCALE_VALUE(scales[1], DNNL_ARG_SRC_1); |
45 | |
46 | const memory_desc_wrapper src0_d(pd()->src_md(0)); |
47 | const memory_desc_wrapper src1_d(pd()->src_md(1)); |
48 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
49 | |
50 | const auto src0_dt = src0_d.data_type(); |
51 | const auto src1_dt = src1_d.data_type(); |
52 | const auto dst_dt = dst_d.data_type(); |
53 | |
54 | const auto alg = pd()->desc()->alg_kind; |
55 | |
56 | const auto nelems = dst_d.nelems(); |
57 | const auto ndims = pd()->ndims(); |
58 | const auto has_postops = pd()->attr()->post_ops_.len() != 0; |
59 | const auto is_inplace |
60 | = static_cast<const void *>(src0) == static_cast<void *>(dst); |
61 | bool has_padding = false; |
62 | for (int i = 0; i < dst_d.ndims(); i++) |
63 | if (dst_d.dims()[i] != dst_d.padded_dims()[i]) { |
64 | has_padding = true; |
65 | break; |
66 | } |
67 | |
68 | if (has_padding && !is_inplace) { |
69 | if (has_postops || !dst_d.is_dense(true)) { |
70 | // Use zero-padding implementation as we cannot memset over |
71 | // populated dst memory or submemories. |
72 | ctx.zero_pad_output(DNNL_ARG_TO); |
73 | } else { |
74 | const auto res = std::div(static_cast<int>(dst_d.size()), PAGE_4K); |
75 | if (!res.quot) |
76 | std::memset(dst, 0, res.rem); |
77 | else |
78 | parallel_nd(res.quot, [&](dim_t i) { |
79 | const auto tail = (i + 1 == res.quot) ? res.rem : 0; |
80 | const auto ptr = reinterpret_cast<unsigned char *>(dst) |
81 | + i * PAGE_4K; |
82 | std::memset(ptr, 0, PAGE_4K + tail); |
83 | }); |
84 | } |
85 | } |
86 | |
87 | parallel_nd(nelems, [&](dim_t i) { |
88 | dims_t dims_src0, dims_src1; // decomposition for physical offsets |
89 | utils::l_dims_by_l_offset(dims_src0, i, dst_d.dims(), ndims); |
90 | utils::l_dims_by_l_offset(dims_src1, i, dst_d.dims(), ndims); |
91 | auto off_C = dst_d.off_v(dims_src0); |
92 | |
93 | int mask_src0 |
94 | = utils::get_dims_mask(dst_d.dims(), src0_d.dims(), ndims); |
95 | utils::apply_mask_on_dims(dims_src0, ndims, mask_src0); |
96 | const auto off_A = src0_d.off_v(dims_src0); |
97 | int mask_src1 |
98 | = utils::get_dims_mask(dst_d.dims(), src1_d.dims(), ndims); |
99 | utils::apply_mask_on_dims(dims_src1, ndims, mask_src1); |
100 | const auto off_B = src1_d.off_v(dims_src1); |
101 | |
102 | float x_f = io::load_float_value(src0_dt, src0, off_A); |
103 | float y_f = io::load_float_value(src1_dt, src1, off_B); |
104 | float dst_f = io::load_float_value(dst_dt, dst, off_C); |
105 | |
106 | x_f *= scales[0][0]; |
107 | y_f *= scales[1][0]; |
108 | |
109 | float acc = compute_binary_scalar(alg, x_f, y_f); |
110 | |
111 | if (has_postops) { |
112 | ref_post_ops_t::args_t args; |
113 | args.dst_val = dst_f; |
114 | args.ctx = &ctx; |
115 | args.l_offset = i; |
116 | args.dst_md = pd()->dst_md(); |
117 | ref_post_ops->execute(acc, args); |
118 | } |
119 | |
120 | io::store_float_value(dst_dt, acc, dst, off_C); |
121 | }); |
122 | |
123 | return status::success; |
124 | } |
125 | |
126 | } // namespace cpu |
127 | } // namespace impl |
128 | } // namespace dnnl |
129 | |
130 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
131 | |