1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <algorithm> |
18 | #include <cmath> |
19 | |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/memory_desc_wrapper.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "cpu/x64/cpu_isa_traits.hpp" |
24 | #include "cpu/x64/prelu/jit_prelu_forward.hpp" |
25 | #include "cpu/x64/prelu/jit_prelu_utils.hpp" |
26 | #include "cpu/x64/prelu/jit_uni_prelu_forward_kernel.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
33 | status_t jit_prelu_fwd_t::pd_t::init(engine_t *engine) { |
34 | const memory_desc_wrapper src_d {src_md(0)}; |
35 | const memory_desc_wrapper weights_d {weights_md(0)}; |
36 | const memory_desc_wrapper dst_d {dst_md(0)}; |
37 | |
38 | const bool ok = is_fwd() |
39 | && prelu::dt_supported({src_d.data_type(), weights_d.data_type(), |
40 | dst_d.data_type()}) |
41 | && set_default_formats() && bcast_supported(src_d, weights_d, dst_d) |
42 | && !has_zero_dim_memory() && src_d.is_dense(true) |
43 | && weights_d.is_dense(true) && attr()->has_default_values() |
44 | && utils::one_of(prelu::get_supported_isa(), avx512_core_fp16, |
45 | avx512_core_bf16, avx512_core, avx2, avx, sse41) |
46 | && dst_d == src_d; |
47 | |
48 | return ok ? status::success : status::unimplemented; |
49 | } |
50 | |
51 | bool jit_prelu_fwd_t::pd_t::bcast_supported(const memory_desc_wrapper &src_d, |
52 | const memory_desc_wrapper &weights_d, |
53 | const memory_desc_wrapper &dst_d) const { |
54 | |
55 | const auto bcast = prelu::get_bcast_type(src_d, weights_d); |
56 | if (bcast == prelu::bcast::full) |
57 | return true; |
58 | else if (bcast == prelu::bcast::unsupported) |
59 | return false; |
60 | else if (bcast == prelu::bcast::per_oc_blocked) { |
61 | const int simd_w = prelu::get_simd_w( |
62 | {src_d.data_type(), weights_d.data_type(), dst_d.data_type()}); |
63 | |
64 | const auto check_block_consistency |
65 | = [&](const memory_desc_wrapper &mdw) { |
66 | const auto &bd = mdw.blocking_desc(); |
67 | |
68 | return bd.inner_nblks == 1 && bd.inner_blks[0] == simd_w |
69 | && bd.inner_idxs[0] == 1; |
70 | }; |
71 | |
72 | return check_block_consistency(src_d) |
73 | && check_block_consistency(weights_d); |
74 | } else { |
75 | const auto &src_strides = src_d.blocking_desc().strides; |
76 | const auto &weights_strides = weights_d.blocking_desc().strides; |
77 | // C should be on second position in tag (example nchw or ncw) or on |
78 | // last postion (nhwc) |
79 | return src_strides[0] >= src_strides[1] |
80 | && IMPLICATION( |
81 | src_strides[1] > 1, src_strides[1] >= src_strides[2]) |
82 | && weights_strides[0] >= weights_strides[1]; |
83 | } |
84 | |
85 | return true; |
86 | } |
87 | |
88 | const jit_prelu_fwd_t::pd_t *jit_prelu_fwd_t::pd() const { |
89 | return static_cast<const pd_t *>(primitive_t::pd().get()); |
90 | } |
91 | |
// Trivial construction/destruction: the jit kernel is created later in
// init(), and kernel_ (a smart pointer) cleans itself up.
jit_prelu_fwd_t::jit_prelu_fwd_t(const pd_t *apd) : primitive_t(apd) {}
jit_prelu_fwd_t::~jit_prelu_fwd_t() = default;
94 | |
95 | status_t jit_prelu_fwd_t::init(engine_t *engine) { |
96 | CHECK(safe_ptr_assign(kernel_, jit_prelu_forward_kernel_t::create(pd()))); |
97 | return kernel_->create_kernel(); |
98 | } |
99 | |
100 | status_t jit_prelu_fwd_t::execute(const exec_ctx_t &ctx) const { |
101 | using byte = unsigned char; |
102 | const byte *const src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC); |
103 | const byte *const weights = CTX_IN_MEM(const byte *, DNNL_ARG_WEIGHTS); |
104 | byte *const dst = CTX_OUT_MEM(byte *, DNNL_ARG_DST); |
105 | const memory_desc_wrapper src_d {pd()->src_md(0)}; |
106 | |
107 | const auto src_dt_size = types::data_type_size(src_d.data_type()); |
108 | const auto weights_dt_size |
109 | = types::data_type_size(pd()->weights_md(0)->data_type); |
110 | const auto dst_dt_size = types::data_type_size(pd()->dst_md(0)->data_type); |
111 | |
112 | const auto kernel = kernel_.get(); |
113 | const auto bcast = kernel->get_bcast(); |
114 | const auto ndims = src_d.ndims(); |
115 | const dim_t MB = pd()->N(); |
116 | const dim_t C = pd()->C(); |
117 | const dim_t D = pd()->D(); |
118 | const dim_t H = pd()->H(); |
119 | const dim_t W = pd()->W(); |
120 | const dim_t SP = D * H * W; |
121 | |
122 | if (bcast == prelu::bcast::full) { |
123 | const auto nelems = src_d.nelems(true); |
124 | const auto simd_w = kernel->simd_w(); |
125 | const auto res = std::div(nelems, simd_w); |
126 | const auto &nelems_simd = res.quot; |
127 | const auto &nelems_tail = res.rem; |
128 | const auto nelems_parallel = nelems_simd + (nelems_tail ? 1 : 0); |
129 | |
130 | parallel(0, [&](const int ithr, const int nthr) { |
131 | dim_t start = 0, end = 0; |
132 | balance211(nelems_parallel, nthr, ithr, start, end); |
133 | if (start >= end) return; |
134 | |
135 | const bool ithr_process_tail |
136 | = nelems_tail && end == nelems_parallel; |
137 | const auto n_simd_size = (end - start - ithr_process_tail) * simd_w; |
138 | const auto offset = start * simd_w; |
139 | |
140 | jit_prelu_forward_kernel_t::call_params_t params; |
141 | |
142 | params.compute_data_size |
143 | = (n_simd_size + (nelems_tail ? nelems_tail : 0)); |
144 | params.src = src + (offset * src_dt_size); |
145 | params.weights = weights + (offset * weights_dt_size); |
146 | params.dst = dst + (offset * dst_dt_size); |
147 | |
148 | (*kernel)(¶ms); |
149 | }); |
150 | } else { |
151 | |
152 | const dim_t nelems_single_mb |
153 | = utils::array_product(src_d.padded_dims() + 1, ndims - 1); |
154 | |
155 | if (bcast == prelu::bcast::per_oc_n_spatial_c) { |
156 | parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) { |
157 | const auto offset = (mb * nelems_single_mb + sp * C); |
158 | jit_prelu_forward_kernel_t::call_params_t params; |
159 | params.compute_data_size = C; |
160 | params.src = src + offset * src_dt_size; |
161 | params.weights = weights; |
162 | params.dst = dst + offset * dst_dt_size; |
163 | (*kernel)(¶ms); |
164 | }); |
165 | } else if (bcast == prelu::bcast::per_oc_n_c_spatial) { |
166 | parallel_nd(MB, C, [&](dim_t mb, dim_t c) { |
167 | jit_prelu_forward_kernel_t::call_params_t params; |
168 | const auto offset = (mb * nelems_single_mb + c * SP); |
169 | params.compute_data_size = SP; |
170 | params.src = src + offset * src_dt_size; |
171 | params.weights = weights + c * weights_dt_size; |
172 | params.dst = dst + offset * dst_dt_size; |
173 | (*kernel)(¶ms); |
174 | }); |
175 | } else if (bcast == prelu::bcast::per_oc_blocked) { |
176 | const auto simd_w = kernel->simd_w(); |
177 | const dim_t C_blocks = std::ceil(static_cast<float>(C) / simd_w); |
178 | |
179 | parallel_nd(MB, C_blocks, [&](dim_t mb, dim_t c_blk) { |
180 | jit_prelu_forward_kernel_t::call_params_t params; |
181 | params.compute_data_size = SP * simd_w; |
182 | const dim_t offset |
183 | = (mb * nelems_single_mb + c_blk * SP * simd_w); |
184 | |
185 | params.src = src + offset * src_dt_size; |
186 | params.weights = weights + c_blk * simd_w * weights_dt_size; |
187 | params.dst = dst + offset * dst_dt_size; |
188 | (*kernel)(¶ms); |
189 | }); |
190 | } |
191 | } |
192 | return status::success; |
193 | } |
194 | |
195 | } // namespace x64 |
196 | } // namespace cpu |
197 | } // namespace impl |
198 | } // namespace dnnl |
199 | |