1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_WINO_REORDER_HPP |
18 | #define CPU_X64_WINO_REORDER_HPP |
19 | |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "common/primitive_desc.hpp" |
23 | |
24 | #include "cpu/cpu_primitive.hpp" |
25 | #include "cpu/reorder/cpu_reorder_pd.hpp" |
26 | #include "cpu/simple_q10n.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
33 | template <data_type_t type_i, data_type_t type_o> |
34 | struct wino_reorder_t : public primitive_t { |
35 | struct pd_t : public cpu_reorder_pd_t { |
36 | using cpu_reorder_pd_t::cpu_reorder_pd_t; |
37 | |
38 | DECLARE_COMMON_PD_T("wino_reorder" , wino_reorder_t); |
39 | |
40 | status_t init( |
41 | engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { |
42 | status_t status |
43 | = cpu_reorder_pd_t::init(engine, src_engine, dst_engine); |
44 | if (status != status::success) return status; |
45 | |
46 | bool ok = attr()->has_default_values( |
47 | primitive_attr_t::skip_mask_t::oscale_runtime |
48 | | primitive_attr_t::skip_mask_t::post_ops); |
49 | if (!ok) return status::unimplemented; |
50 | |
51 | init_scratchpad(); |
52 | |
53 | return status::success; |
54 | } |
55 | |
56 | int nthr_; // To not exceed the limit in execute used for set up. |
57 | |
58 | private: |
59 | static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, |
60 | const primitive_attr_t *attr, engine_t *src_engine, |
61 | const memory_desc_t *src_md, engine_t *dst_engine, |
62 | const memory_desc_t *dst_md) { |
63 | const memory_desc_wrapper id(src_md), od(dst_md); |
64 | bool args_ok = true && id.data_type() == type_i |
65 | && od.data_type() == type_o |
66 | && od.format_kind() == format_kind::wino |
67 | && utils::one_of(od.wino_desc().wino_format, |
68 | wino_memory_format_t::wino_wei_aaOio, |
69 | wino_memory_format_t::wino_wei_aaOBiOo, |
70 | wino_memory_format_t::wino_wei_OBaaIBOIio) |
71 | && (id.matches_tag(utils::pick(id.ndims() - 4, |
72 | format_tag::oihw, format_tag::goihw)) |
73 | || id.matches_tag(utils::pick(id.ndims() - 4, |
74 | format_tag::hwio, format_tag::hwigo))); |
75 | if (!args_ok) return status::invalid_arguments; |
76 | |
77 | auto _pd = new pd_t(attr, src_engine->kind(), src_md, |
78 | dst_engine->kind(), dst_md); |
79 | if (_pd == nullptr) return status::out_of_memory; |
80 | if (_pd->init(engine, src_engine, dst_engine) != status::success) { |
81 | delete _pd; |
82 | return status::unimplemented; |
83 | } |
84 | _pd->init_scratchpad_md(); |
85 | return safe_ptr_assign(*reorder_pd, _pd); |
86 | } |
87 | |
88 | void init_scratchpad() { |
89 | const auto &wino_desc = memory_desc_wrapper(dst_md()).wino_desc(); |
90 | const int nb_oc = wino_desc.oc / wino_desc.oc_block; |
91 | const int work_amount = nb_oc * wino_desc.ic; |
92 | nthr_ = nstl::min(dnnl_get_max_threads(), work_amount); |
93 | const size_t transform_space_size = static_cast<size_t>(wino_desc.r) |
94 | * wino_desc.alpha * wino_desc.oc_block * nthr_; |
95 | const size_t plain_size = static_cast<size_t>(wino_desc.alpha) |
96 | * wino_desc.alpha * wino_desc.oc * wino_desc.ic; |
97 | |
98 | using namespace memory_tracking::names; |
99 | auto scratchpad = scratchpad_registry().registrar(); |
100 | scratchpad.template book<in_data_t>( |
101 | key_reorder_wino_transform_space, transform_space_size); |
102 | scratchpad.template book<out_data_t>( |
103 | key_reorder_wino_plain, plain_size); |
104 | } |
105 | friend dnnl::impl::impl_list_item_t; |
106 | }; |
107 | |
108 | wino_reorder_t(const pd_t *apd) : primitive_t(apd) {} |
109 | |
110 | status_t init(engine_t *engine) override { |
111 | const memory_desc_wrapper src_d(pd()->src_md()); |
112 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
113 | |
114 | r_ = dst_d.wino_desc().r; |
115 | w_alpha_ = dst_d.wino_desc().alpha; |
116 | wino_format_ = dst_d.wino_desc().wino_format; |
117 | |
118 | const auto &in_dims = src_d.dims(); |
119 | int groups; |
120 | int groups_offset; |
121 | if (src_d.ndims() == 5) { |
122 | groups = in_dims[0]; |
123 | groups_offset = 1; |
124 | } else { |
125 | groups = 1; |
126 | groups_offset = 0; |
127 | } |
128 | assert(groups == 1); // groups are not supported now |
129 | MAYBE_UNUSED(groups); |
130 | |
131 | or_oc_ = in_dims[0 + groups_offset]; |
132 | or_ic_ = in_dims[1 + groups_offset]; |
133 | kh_ = in_dims[2 + groups_offset]; |
134 | kw_ = in_dims[3 + groups_offset]; |
135 | |
136 | oc_ = dst_d.wino_desc().oc; |
137 | ic_ = dst_d.wino_desc().ic; |
138 | oc_block_ = dst_d.wino_desc().oc_block; |
139 | ic_block_ = dst_d.wino_desc().ic_block; |
140 | assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0); |
141 | nb_oc_ = oc_ / oc_block_; |
142 | nb_ic_ = ic_ / ic_block_; |
143 | ic2_block_ = 1; |
144 | if (wino_format_ == wino_memory_format_t::wino_wei_OBaaIBOIio) |
145 | ic2_block_ = dst_d.wino_desc().ic2_block; |
146 | oc2_block_ = dst_d.wino_desc().oc2_block; |
147 | assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0); |
148 | |
149 | adj_scale_ = dst_d.wino_desc().adj_scale; |
150 | |
151 | size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_; |
152 | work_amount_ = ic_ * nb_oc_; |
153 | size_wspace_thr_ = r_ * w_alpha_ * oc_block_; |
154 | |
155 | return status::success; |
156 | } |
157 | |
158 | private: |
159 | typedef typename prec_traits<type_i>::type in_data_t; |
160 | typedef typename prec_traits<type_o>::type out_data_t; |
161 | const int unsign_val_in_wino_domain_ = 5; |
162 | |
163 | void transform(out_data_t *__restrict tmp_wei, |
164 | const in_data_t *__restrict input, in_data_t *__restrict wspace, |
165 | const float *__restrict oscales) const { |
166 | const memory_desc_wrapper src_d(pd()->src_md()); |
167 | |
168 | /* transform weights to winograd domain */ |
169 | const float G_2x2_3x3[4][3] = {{1.0, 0.0, 0.0}, {0.5, 0.5, 0.5}, |
170 | {0.5, -0.5, 0.5}, {0.0, 0.0, 1.0}}; |
171 | |
172 | const float G_4x4_3x3[6][3] = {{1.13777777777778f, 0.f, 0.f}, |
173 | {-0.688403361344538f, -0.430252100840336f, -0.26890756302521f}, |
174 | {-0.688403361344538f, 0.430252100840336f, -0.26890756302521f}, |
175 | {0.119514472455649f, 0.179271708683473f, 0.26890756302521f}, |
176 | {0.119514472455649f, -0.179271708683473f, 0.26890756302521f}, |
177 | {0.f, 0.f, 1.f}}; |
178 | |
179 | float *__restrict g; |
180 | if (utils::one_of(wino_format_, wino_memory_format_t::wino_wei_aaOio, |
181 | wino_memory_format_t::wino_wei_aaOBiOo)) |
182 | g = (float *)G_2x2_3x3; |
183 | else if (wino_format_ == wino_memory_format_t::wino_wei_OBaaIBOIio) |
184 | g = (float *)G_4x4_3x3; |
185 | else { |
186 | assert(!"Unknown winograd weights target layout" ); |
187 | return; |
188 | } |
189 | |
190 | const bool has_oihw_format = false |
191 | || src_d.matches_tag(format_tag::oihw) |
192 | || src_d.matches_tag(format_tag::goihw); |
193 | |
194 | const int Z = oc_ * ic_; |
195 | const int or_ioc_ = or_ic_ * or_oc_; |
196 | assert(r_ == kh_ && r_ == kw_); |
197 | const int nthr = pd()->nthr_; |
198 | |
199 | parallel_nd_ext(nthr, ic_, nb_oc_, |
200 | [&](int ithr, int nthr, dim_t iic, dim_t ob) { |
201 | if (ithr >= work_amount_) return; |
202 | |
203 | const in_data_t *__restrict _inp = has_oihw_format |
204 | ? input |
205 | + (ob * oc_block_ * or_ic_ + iic) * kh_ |
206 | * kw_ |
207 | : input + iic * or_oc_ + ob * oc_block_; |
208 | out_data_t *__restrict _out |
209 | = tmp_wei + (iic * nb_oc_ + ob) * oc_block_; |
210 | |
211 | in_data_t *__restrict wspace_thr |
212 | = wspace + ithr * size_wspace_thr_; |
213 | |
214 | std::memset(wspace_thr, 0.f, |
215 | size_wspace_thr_ * sizeof(in_data_t)); |
216 | |
217 | if (has_oihw_format) { |
218 | for_(int ih = 0; ih < r_; ++ih) |
219 | for_(int j = 0; j < w_alpha_; ++j) |
220 | for (int ioc = 0; ioc < oc_block_; ++ioc) { |
221 | PRAGMA_OMP_SIMD() |
222 | for (int iw = 0; iw < r_; ++iw) { |
223 | const int inp_oc = ob * oc_block_ + ioc; |
224 | const int inp_ic = iic; |
225 | in_data_t inp_v |
226 | = (inp_ic < or_ic_ && inp_oc < or_oc_) |
227 | ? _inp[ioc * or_ic_ * kh_ * kw_ |
228 | + ih * kw_ + iw] |
229 | : 0.f; |
230 | wspace_thr[(ih * w_alpha_ + j) * oc_block_ |
231 | + ioc] |
232 | += inp_v * g[j * r_ + iw]; |
233 | } |
234 | } |
235 | } else { // hwio format case |
236 | for_(int ih = 0; ih < r_; ++ih) |
237 | for_(int j = 0; j < w_alpha_; ++j) |
238 | for (int iw = 0; iw < kw_; ++iw) { |
239 | const float g_multiplier = g[j * r_ + iw]; |
240 | const in_data_t *__restrict inp_base |
241 | = _inp + or_ioc_ * (iw + ih * kw_); |
242 | in_data_t *__restrict wspace_base = wspace_thr |
243 | + (ih * w_alpha_ + j) * oc_block_; |
244 | |
245 | PRAGMA_OMP_SIMD() |
246 | for (int ioc = 0; ioc < oc_block_; ++ioc) { |
247 | const int inp_oc = ob * oc_block_ + ioc; |
248 | const int inp_ic = iic; |
249 | in_data_t inp_v |
250 | = (inp_ic < or_ic_ && inp_oc < or_oc_) |
251 | ? inp_base[ioc] |
252 | : 0.f; |
253 | |
254 | wspace_base[ioc] += inp_v * g_multiplier; |
255 | } |
256 | } |
257 | } |
258 | |
259 | for_(int i = 0; i < w_alpha_; ++i) |
260 | for_(int j = 0; j < w_alpha_; ++j) |
261 | for (int ioc = 0; ioc < oc_block_; ++ioc) { |
262 | float res = 0; |
263 | PRAGMA_OMP_SIMD(reduction(+ : res)) |
264 | for (int k = 0; k < r_; ++k) |
265 | res += g[i * r_ + k] |
266 | * wspace_thr[(k * w_alpha_ + j) * oc_block_ |
267 | + ioc]; |
268 | _out[(i * w_alpha_ + j) * Z + ioc] |
269 | = static_cast<out_data_t>(res); |
270 | } |
271 | }); |
272 | } |
273 | |
274 | void reorder_to_aaOio(out_data_t *__restrict output, |
275 | const out_data_t *__restrict tmp_wei) const { |
276 | parallel_nd(w_alpha_, w_alpha_, nb_oc_, |
277 | [&](dim_t u_h, dim_t u_w, dim_t ob) { |
278 | for_(int ib = 0; ib < nb_ic_; ib++) |
279 | for_(int i = 0; i < ic_block_; i++) |
280 | for (int o = 0; o < oc_block_; o++) { |
281 | const int src_offset = u_h * w_alpha_ * ic_ * oc_ |
282 | + u_w * ic_ * oc_ + (ib * ic_block_ + i) * oc_ |
283 | + (ob * oc_block_ + o); |
284 | |
285 | const int dst_offset = u_h * w_alpha_ * nb_oc_ * nb_ic_ |
286 | * ic_block_ * oc_block_ |
287 | + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ |
288 | + ob * nb_ic_ * ic_block_ * oc_block_ |
289 | + ib * ic_block_ * oc_block_ + i * oc_block_ |
290 | + o; |
291 | output[dst_offset] = tmp_wei[src_offset]; |
292 | } |
293 | }); |
294 | } |
295 | |
296 | void reorder_to_aaOBiOo(out_data_t *__restrict output, |
297 | const out_data_t *__restrict tmp_wei) const { |
298 | const int oc_chunks = nb_oc_ / oc2_block_; |
299 | parallel_nd(w_alpha_, w_alpha_, oc_chunks, |
300 | [&](dim_t u_h, dim_t u_w, dim_t occ) { |
301 | for (int ib = 0; ib < nb_ic_; ib++) { |
302 | out_data_t *__restrict wei_ptr = output |
303 | + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) |
304 | * nb_ic_ |
305 | + ib) |
306 | * oc2_block_ * ic_block_ * oc_block_; |
307 | int wei_offset = 0; |
308 | for_(int i = 0; i < ic_block_; i++) |
309 | for (int ob2 = 0; ob2 < oc2_block_; ob2++) { |
310 | for (int o = 0; o < oc_block_; o++) { |
311 | const int icp = ib * ic_block_ + i; |
312 | const int ocp = occ * oc2_block_ * oc_block_ |
313 | + ob2 * oc_block_ + o; |
314 | |
315 | const int src_offset |
316 | = u_h * w_alpha_ * ic_ * oc_ |
317 | + u_w * ic_ * oc_ + icp * oc_ + ocp; |
318 | wei_ptr[wei_offset + o] = tmp_wei[src_offset]; |
319 | } |
320 | wei_offset += oc_block_; |
321 | } |
322 | } |
323 | }); |
324 | } |
325 | |
326 | void reorder_to_OBaaIBOIio(out_data_t *__restrict output, |
327 | const out_data_t *__restrict tmp_wei) const { |
328 | const int ic_chunks = nb_ic_ / ic2_block_; |
329 | const int oc_chunks = nb_oc_ / oc2_block_; |
330 | parallel_nd(oc_chunks, w_alpha_, w_alpha_, |
331 | [&](dim_t occ, dim_t u_h, dim_t u_w) { |
332 | for_(int icc = 0; icc < ic_chunks; icc++) |
333 | for (int ob = 0; ob < oc2_block_; ob++) { |
334 | const int ocp = (occ * oc2_block_ + ob) * oc_block_; |
335 | for_(int ib = 0; ib < ic2_block_; ib++) |
336 | for (int i = 0; i < ic_block_; i++) { |
337 | const int icp |
338 | = (icc * ic2_block_ + ib) * ic_block_ + i; |
339 | |
340 | const int src_offset = u_h * w_alpha_ * ic_ * oc_ |
341 | + u_w * ic_ * oc_ + icp * oc_ + ocp; |
342 | const int wei_offset |
343 | = ((((((occ * w_alpha_ + u_h) * w_alpha_ |
344 | + u_w) * ic_chunks |
345 | + icc) * oc2_block_ |
346 | + ob) * ic2_block_ |
347 | + ib) * ic_block_ |
348 | + i) |
349 | * oc_block_; |
350 | for (int o = 0; o < oc_block_; o++) |
351 | output[wei_offset + o] |
352 | = tmp_wei[src_offset + o]; |
353 | } |
354 | } |
355 | }); |
356 | } |
357 | |
358 | status_t execute(const exec_ctx_t &ctx) const override { |
359 | auto input = CTX_IN_MEM(const in_data_t *, DNNL_ARG_FROM); |
360 | auto output = CTX_OUT_MEM(out_data_t *, DNNL_ARG_TO); |
361 | |
362 | auto wspace = (in_data_t * __restrict) ctx.get_scratchpad_grantor() |
363 | .template get<void>(memory_tracking::names:: |
364 | key_reorder_wino_transform_space); |
365 | auto tmp_wei = (out_data_t * __restrict) ctx.get_scratchpad_grantor() |
366 | .template get<void>(memory_tracking::names:: |
367 | key_reorder_wino_plain); |
368 | |
369 | DEFINE_SCALES_BUFFER(oscales); |
370 | |
371 | transform(tmp_wei, input, wspace, oscales); |
372 | |
373 | /* reorder to winograd domain */ |
374 | switch (wino_format_) { |
375 | case wino_memory_format_t::wino_wei_aaOio: |
376 | reorder_to_aaOio(output, tmp_wei); |
377 | break; |
378 | case wino_memory_format_t::wino_wei_aaOBiOo: |
379 | reorder_to_aaOBiOo(output, tmp_wei); |
380 | break; |
381 | case wino_memory_format_t::wino_wei_OBaaIBOIio: |
382 | reorder_to_OBaaIBOIio(output, tmp_wei); |
383 | break; |
384 | default: assert(!"Unknown wino format" ); break; |
385 | } |
386 | |
387 | return status::success; |
388 | } |
389 | |
390 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
391 | int r_, w_alpha_; |
392 | dim_t ic_, oc_, or_ic_, or_oc_, kh_, kw_; |
393 | dim_t oc_block_, ic_block_, oc2_block_, ic2_block_; |
394 | float adj_scale_; |
395 | dim_t nb_oc_, nb_ic_; |
396 | wino_memory_format_t wino_format_; |
397 | int size_wino_wei_; |
398 | int size_wspace_thr_; |
399 | int work_amount_; |
400 | }; |
401 | |
402 | } // namespace x64 |
403 | } // namespace cpu |
404 | } // namespace impl |
405 | } // namespace dnnl |
406 | |
407 | #endif |
408 | |