/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_WINO_REORDER_HPP
#define CPU_X64_WINO_REORDER_HPP

#include <cassert>
#include <cstring>

#include "common/dnnl_thread.hpp"
#include "common/primitive.hpp"
#include "common/primitive_desc.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/reorder/cpu_reorder_pd.hpp"
#include "cpu/simple_q10n.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <data_type_t type_i, data_type_t type_o>
struct wino_reorder_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("wino_reorder", wino_reorder_t);

        status_t init(
                engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
            status_t status
                    = cpu_reorder_pd_t::init(engine, src_engine, dst_engine);
            if (status != status::success) return status;

            bool ok = attr()->has_default_values(
                    primitive_attr_t::skip_mask_t::oscale_runtime
                    | primitive_attr_t::skip_mask_t::post_ops);
            if (!ok) return status::unimplemented;

            init_scratchpad();

            return status::success;
        }

        // Thread count used to size the scratchpad; execute() must not
        // exceed it.
        int nthr_;

    private:
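        // Dispatches this reorder only for plain (g)oihw / hwi(g)o weights on
        // the source side and one of the wino_wei_* formats on the
        // destination side.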
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            const memory_desc_wrapper id(src_md), od(dst_md);
            bool args_ok = true && id.data_type() == type_i
                    && od.data_type() == type_o
                    && od.format_kind() == format_kind::wino
                    && utils::one_of(od.wino_desc().wino_format,
                            wino_memory_format_t::wino_wei_aaOio,
                            wino_memory_format_t::wino_wei_aaOBiOo,
                            wino_memory_format_t::wino_wei_OBaaIBOIio)
                    && (id.matches_tag(utils::pick(id.ndims() - 4,
                                format_tag::oihw, format_tag::goihw))
                            || id.matches_tag(utils::pick(id.ndims() - 4,
                                    format_tag::hwio, format_tag::hwigo)));
            if (!args_ok) return status::invalid_arguments;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return status::out_of_memory;
            if (_pd->init(engine, src_engine, dst_engine) != status::success) {
                delete _pd;
                return status::unimplemented;
            }
            _pd->init_scratchpad_md();
            return safe_ptr_assign(*reorder_pd, _pd);
        }

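        // Scratchpad: a per-thread r x alpha x oc_block transform workspace
        // plus a plain alpha x alpha x oc x ic buffer that holds all
        // transformed weights before they are repacked into the final wino
        // layout.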
        void init_scratchpad() {
            const auto &wino_desc = memory_desc_wrapper(dst_md()).wino_desc();
            const int nb_oc = wino_desc.oc / wino_desc.oc_block;
            const int work_amount = nb_oc * wino_desc.ic;
            nthr_ = nstl::min(dnnl_get_max_threads(), work_amount);
            const size_t transform_space_size
                    = static_cast<size_t>(wino_desc.r) * wino_desc.alpha
                    * wino_desc.oc_block * nthr_;
            const size_t plain_size = static_cast<size_t>(wino_desc.alpha)
                    * wino_desc.alpha * wino_desc.oc * wino_desc.ic;

            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.template book<in_data_t>(
                    key_reorder_wino_transform_space, transform_space_size);
            scratchpad.template book<out_data_t>(
                    key_reorder_wino_plain, plain_size);
        }
        friend dnnl::impl::impl_list_item_t;
    };

    wino_reorder_t(const pd_t *apd) : primitive_t(apd) {}

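    // Caches the original weight geometry (or_oc_, or_ic_, kh_, kw_) and the
    // padded, blocked Winograd geometry (oc_, ic_, *_block_) taken from the
    // memory descriptors.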
    status_t init(engine_t *engine) override {
        const memory_desc_wrapper src_d(pd()->src_md());
        const memory_desc_wrapper dst_d(pd()->dst_md());

        r_ = dst_d.wino_desc().r;
        w_alpha_ = dst_d.wino_desc().alpha;
        wino_format_ = dst_d.wino_desc().wino_format;

        const auto &in_dims = src_d.dims();
        int groups;
        int groups_offset;
        if (src_d.ndims() == 5) {
            groups = in_dims[0];
            groups_offset = 1;
        } else {
            groups = 1;
            groups_offset = 0;
        }
        assert(groups == 1); // groups are not supported now
        MAYBE_UNUSED(groups);

        or_oc_ = in_dims[0 + groups_offset];
        or_ic_ = in_dims[1 + groups_offset];
        kh_ = in_dims[2 + groups_offset];
        kw_ = in_dims[3 + groups_offset];

        oc_ = dst_d.wino_desc().oc;
        ic_ = dst_d.wino_desc().ic;
        oc_block_ = dst_d.wino_desc().oc_block;
        ic_block_ = dst_d.wino_desc().ic_block;
        assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0);
        nb_oc_ = oc_ / oc_block_;
        nb_ic_ = ic_ / ic_block_;
        ic2_block_ = 1;
        if (wino_format_ == wino_memory_format_t::wino_wei_OBaaIBOIio)
            ic2_block_ = dst_d.wino_desc().ic2_block;
        oc2_block_ = dst_d.wino_desc().oc2_block;
        assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0);

        adj_scale_ = dst_d.wino_desc().adj_scale;

        size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_;
        work_amount_ = ic_ * nb_oc_;
        size_wspace_thr_ = r_ * w_alpha_ * oc_block_;

        return status::success;
    }

private:
    typedef typename prec_traits<type_i>::type in_data_t;
    typedef typename prec_traits<type_o>::type out_data_t;
    const int unsign_val_in_wino_domain_ = 5;

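    // Computes the Winograd weight transform U = G * g * G^T for each
    // (ic, oc_block) tile and stores the result into tmp_wei in a plain
    // [alpha][alpha][ic][oc] layout; the reorder_to_* routines below repack
    // that buffer into the requested wino format.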
    void transform(out_data_t *__restrict tmp_wei,
            const in_data_t *__restrict input, in_data_t *__restrict wspace,
            const float *__restrict oscales) const {
        const memory_desc_wrapper src_d(pd()->src_md());

        /* transform weights to winograd domain */
        const float G_2x2_3x3[4][3] = {{1.0, 0.0, 0.0}, {0.5, 0.5, 0.5},
                {0.5, -0.5, 0.5}, {0.0, 0.0, 1.0}};

        const float G_4x4_3x3[6][3] = {{1.13777777777778f, 0.f, 0.f},
                {-0.688403361344538f, -0.430252100840336f, -0.26890756302521f},
                {-0.688403361344538f, 0.430252100840336f, -0.26890756302521f},
                {0.119514472455649f, 0.179271708683473f, 0.26890756302521f},
                {0.119514472455649f, -0.179271708683473f, 0.26890756302521f},
                {0.f, 0.f, 1.f}};
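
        // G_2x2_3x3 is the textbook transform matrix for F(2x2, 3x3). The
        // F(4x4, 3x3) coefficients are not the textbook ones; they appear to
        // be pre-scaled, presumably to make better use of the integer range,
        // with wino_desc().adj_scale compensating on the compute side.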

        float *__restrict g;
        if (utils::one_of(wino_format_, wino_memory_format_t::wino_wei_aaOio,
                    wino_memory_format_t::wino_wei_aaOBiOo))
            g = (float *)G_2x2_3x3;
        else if (wino_format_ == wino_memory_format_t::wino_wei_OBaaIBOIio)
            g = (float *)G_4x4_3x3;
        else {
            assert(!"Unknown winograd weights target layout");
            return;
        }

        const bool has_oihw_format = false
                || src_d.matches_tag(format_tag::oihw)
                || src_d.matches_tag(format_tag::goihw);

        const int Z = oc_ * ic_;
        const int or_ioc_ = or_ic_ * or_oc_;
        assert(r_ == kh_ && r_ == kw_);
        const int nthr = pd()->nthr_;

        parallel_nd_ext(nthr, ic_, nb_oc_,
                [&](int ithr, int nthr, dim_t iic, dim_t ob) {
                    if (ithr >= work_amount_) return;

                    const in_data_t *__restrict _inp = has_oihw_format
                            ? input
                                    + (ob * oc_block_ * or_ic_ + iic) * kh_
                                            * kw_
                            : input + iic * or_oc_ + ob * oc_block_;
                    out_data_t *__restrict _out
                            = tmp_wei + (iic * nb_oc_ + ob) * oc_block_;

                    in_data_t *__restrict wspace_thr
                            = wspace + ithr * size_wspace_thr_;

                    std::memset(wspace_thr, 0,
                            size_wspace_thr_ * sizeof(in_data_t));

                    if (has_oihw_format) {
                        for_(int ih = 0; ih < r_; ++ih)
                        for_(int j = 0; j < w_alpha_; ++j)
                        for (int ioc = 0; ioc < oc_block_; ++ioc) {
                            PRAGMA_OMP_SIMD()
                            for (int iw = 0; iw < r_; ++iw) {
                                const int inp_oc = ob * oc_block_ + ioc;
                                const int inp_ic = iic;
                                in_data_t inp_v
                                        = (inp_ic < or_ic_ && inp_oc < or_oc_)
                                        ? _inp[ioc * or_ic_ * kh_ * kw_
                                                + ih * kw_ + iw]
                                        : 0.f;
                                wspace_thr[(ih * w_alpha_ + j) * oc_block_
                                        + ioc]
                                        += inp_v * g[j * r_ + iw];
                            }
                        }
                    } else { // hwio format case
                        for_(int ih = 0; ih < r_; ++ih)
                        for_(int j = 0; j < w_alpha_; ++j)
                        for (int iw = 0; iw < kw_; ++iw) {
                            const float g_multiplier = g[j * r_ + iw];
                            const in_data_t *__restrict inp_base
                                    = _inp + or_ioc_ * (iw + ih * kw_);
                            in_data_t *__restrict wspace_base = wspace_thr
                                    + (ih * w_alpha_ + j) * oc_block_;

                            PRAGMA_OMP_SIMD()
                            for (int ioc = 0; ioc < oc_block_; ++ioc) {
                                const int inp_oc = ob * oc_block_ + ioc;
                                const int inp_ic = iic;
                                in_data_t inp_v
                                        = (inp_ic < or_ic_ && inp_oc < or_oc_)
                                        ? inp_base[ioc]
                                        : 0.f;

                                wspace_base[ioc] += inp_v * g_multiplier;
                            }
                        }
                    }

                    for_(int i = 0; i < w_alpha_; ++i)
                    for_(int j = 0; j < w_alpha_; ++j)
                    for (int ioc = 0; ioc < oc_block_; ++ioc) {
                        float res = 0;
                        PRAGMA_OMP_SIMD(reduction(+ : res))
                        for (int k = 0; k < r_; ++k)
                            res += g[i * r_ + k]
                                    * wspace_thr[(k * w_alpha_ + j) * oc_block_
                                            + ioc];
                        _out[(i * w_alpha_ + j) * Z + ioc]
                                = static_cast<out_data_t>(res);
                    }
                });
    }

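    // Repacks the plain [alpha][alpha][ic][oc] buffer into the aaOio layout:
    // [alpha][alpha][nb_oc][nb_ic][ic_block][oc_block].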
    void reorder_to_aaOio(out_data_t *__restrict output,
            const out_data_t *__restrict tmp_wei) const {
        parallel_nd(w_alpha_, w_alpha_, nb_oc_,
                [&](dim_t u_h, dim_t u_w, dim_t ob) {
                    for_(int ib = 0; ib < nb_ic_; ib++)
                    for_(int i = 0; i < ic_block_; i++)
                    for (int o = 0; o < oc_block_; o++) {
                        const int src_offset = u_h * w_alpha_ * ic_ * oc_
                                + u_w * ic_ * oc_ + (ib * ic_block_ + i) * oc_
                                + (ob * oc_block_ + o);

                        const int dst_offset = u_h * w_alpha_ * nb_oc_ * nb_ic_
                                        * ic_block_ * oc_block_
                                + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_
                                + ob * nb_ic_ * ic_block_ * oc_block_
                                + ib * ic_block_ * oc_block_ + i * oc_block_
                                + o;
                        output[dst_offset] = tmp_wei[src_offset];
                    }
                });
    }

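    // Repacks into the aaOBiOo layout:
    // [alpha][alpha][oc_chunks][nb_ic][ic_block][oc2_block][oc_block],
    // where oc_chunks = nb_oc_ / oc2_block_.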
    void reorder_to_aaOBiOo(out_data_t *__restrict output,
            const out_data_t *__restrict tmp_wei) const {
        const int oc_chunks = nb_oc_ / oc2_block_;
        parallel_nd(w_alpha_, w_alpha_, oc_chunks,
                [&](dim_t u_h, dim_t u_w, dim_t occ) {
                    for (int ib = 0; ib < nb_ic_; ib++) {
                        out_data_t *__restrict wei_ptr = output
                                + (((u_h * w_alpha_ + u_w) * oc_chunks + occ)
                                                  * nb_ic_
                                          + ib)
                                        * oc2_block_ * ic_block_ * oc_block_;
                        int wei_offset = 0;
                        for_(int i = 0; i < ic_block_; i++)
                        for (int ob2 = 0; ob2 < oc2_block_; ob2++) {
                            for (int o = 0; o < oc_block_; o++) {
                                const int icp = ib * ic_block_ + i;
                                const int ocp = occ * oc2_block_ * oc_block_
                                        + ob2 * oc_block_ + o;

                                const int src_offset
                                        = u_h * w_alpha_ * ic_ * oc_
                                        + u_w * ic_ * oc_ + icp * oc_ + ocp;
                                wei_ptr[wei_offset + o] = tmp_wei[src_offset];
                            }
                            wei_offset += oc_block_;
                        }
                    }
                });
    }

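    // Repacks into the OBaaIBOIio layout:
    // [oc_chunks][alpha][alpha][ic_chunks][oc2_block][ic2_block][ic_block]
    // [oc_block].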
    void reorder_to_OBaaIBOIio(out_data_t *__restrict output,
            const out_data_t *__restrict tmp_wei) const {
        const int ic_chunks = nb_ic_ / ic2_block_;
        const int oc_chunks = nb_oc_ / oc2_block_;
        parallel_nd(oc_chunks, w_alpha_, w_alpha_,
                [&](dim_t occ, dim_t u_h, dim_t u_w) {
                    for_(int icc = 0; icc < ic_chunks; icc++)
                    for (int ob = 0; ob < oc2_block_; ob++) {
                        const int ocp = (occ * oc2_block_ + ob) * oc_block_;
                        for_(int ib = 0; ib < ic2_block_; ib++)
                        for (int i = 0; i < ic_block_; i++) {
                            const int icp
                                    = (icc * ic2_block_ + ib) * ic_block_ + i;

                            const int src_offset = u_h * w_alpha_ * ic_ * oc_
                                    + u_w * ic_ * oc_ + icp * oc_ + ocp;
                            const int wei_offset
                                    = ((((((occ * w_alpha_ + u_h) * w_alpha_
                                                  + u_w) * ic_chunks
                                                 + icc) * oc2_block_
                                                + ob) * ic2_block_
                                               + ib) * ic_block_
                                              + i)
                                    * oc_block_;
                            for (int o = 0; o < oc_block_; o++)
                                output[wei_offset + o]
                                        = tmp_wei[src_offset + o];
                        }
                    }
                });
    }

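    // execute(): transform the plain weights into the Winograd domain in the
    // scratchpad, then repack them into the wino layout requested by dst_md.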
    status_t execute(const exec_ctx_t &ctx) const override {
        auto input = CTX_IN_MEM(const in_data_t *, DNNL_ARG_FROM);
        auto output = CTX_OUT_MEM(out_data_t *, DNNL_ARG_TO);

        auto wspace = (in_data_t *__restrict)ctx.get_scratchpad_grantor()
                              .template get<void>(memory_tracking::names::
                                              key_reorder_wino_transform_space);
        auto tmp_wei = (out_data_t *__restrict)ctx.get_scratchpad_grantor()
                               .template get<void>(memory_tracking::names::
                                               key_reorder_wino_plain);

        DEFINE_SCALES_BUFFER(oscales);

        transform(tmp_wei, input, wspace, oscales);

        /* reorder to the requested winograd weights layout */
        switch (wino_format_) {
            case wino_memory_format_t::wino_wei_aaOio:
                reorder_to_aaOio(output, tmp_wei);
                break;
            case wino_memory_format_t::wino_wei_aaOBiOo:
                reorder_to_aaOBiOo(output, tmp_wei);
                break;
            case wino_memory_format_t::wino_wei_OBaaIBOIio:
                reorder_to_OBaaIBOIio(output, tmp_wei);
                break;
            default: assert(!"Unknown wino format"); break;
        }

        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    int r_, w_alpha_;
    dim_t ic_, oc_, or_ic_, or_oc_, kh_, kw_;
    dim_t oc_block_, ic_block_, oc2_block_, ic2_block_;
    float adj_scale_;
    dim_t nb_oc_, nb_ic_;
    wino_memory_format_t wino_format_;
    int size_wino_wei_;
    int size_wspace_thr_;
    int work_amount_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif