1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <iterator> |
18 | |
19 | #include <float.h> |
20 | #include <math.h> |
21 | #include <stdio.h> |
22 | #include <stdlib.h> |
23 | |
24 | #include "oneapi/dnnl/dnnl.h" |
25 | |
26 | #include "dnnl_common.hpp" |
27 | #include "dnnl_memory.hpp" |
28 | |
29 | #include "binary/binary.hpp" |
30 | #include "conv/conv_dw_fusion.hpp" |
31 | |
32 | namespace conv_dw_fusion { |
33 | |
34 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
35 | const prb_t *prb = init_pd_args.prb; |
36 | |
37 | auto src_d = dnn_mem_t::init_md( |
38 | prb->ndims, prb->src_dims().data(), prb->cfg[SRC].dt, prb->stag); |
39 | auto wei_d = dnn_mem_t::init_md(prb->ndims + prb->has_groups, |
40 | prb->wei_dims().data(), prb->cfg[WEI].dt, prb->wtag); |
41 | auto bia_d = dnn_mem_t::init_md( |
42 | 1, prb->bia_dims().data(), prb->cfg[BIA].dt, tag::any); |
43 | auto dst_d = dnn_mem_t::init_md( |
44 | prb->ndims, prb->dst_dims().data(), prb->cfg[DST].dt, prb->dtag); |
45 | |
46 | dnnl_alg_kind_t alg = dnnl_convolution_direct; |
47 | if (prb->alg == alg_t::WINO) alg = dnnl_convolution_winograd; |
48 | if (prb->alg == alg_t::AUTO) alg = dnnl_convolution_auto; |
49 | |
50 | attr_args_t attr_args; |
51 | |
52 | auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS); |
53 | if (wei_scale.policy == policy_t::PER_OC) { |
54 | auto wei_mask = prb->has_groups ? 2 : 1; |
55 | attr_args.prepare_scales(prb->attr, DNNL_ARG_WEIGHTS, prb->wei_scales, |
56 | prb->oc, wei_mask); |
57 | } |
58 | const auto dw_bia_dt = prb->dir == FWD_B ? dnnl_f32 : dnnl_data_type_undef; |
59 | attr_args.prepare_dw_post_op(prb->attr, prb->cfg[WEI].dt, dw_bia_dt); |
60 | attr_args.prepare_post_ops_mds( |
61 | prb->attr, prb->ndims, prb->dst_dims().data()); |
62 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
63 | create_dnnl_attr(prb->attr, attr_args)); |
64 | |
65 | switch (prb->dir) { |
66 | case FWD_D: |
67 | case FWD_B: |
68 | case FWD_I: |
69 | if (prb->dir != FWD_B) bia_d.reset(nullptr); |
70 | DNN_SAFE_STATUS(dnnl_convolution_forward_primitive_desc_create( |
71 | &init_pd_args.pd, init_pd_args.engine, |
72 | prb->dir == FWD_I ? dnnl_forward_inference |
73 | : dnnl_forward_training, |
74 | alg, src_d, wei_d, bia_d, dst_d, prb->strides().data(), |
75 | prb->dilations().data(), prb->padding().data(), |
76 | prb->padding_r().data(), dnnl_attr)); |
77 | break; |
78 | case BWD_D: |
79 | DNN_SAFE_STATUS( |
80 | dnnl_convolution_backward_data_primitive_desc_create( |
81 | &init_pd_args.pd, init_pd_args.engine, alg, src_d, |
82 | wei_d, dst_d, prb->strides().data(), |
83 | prb->dilations().data(), prb->padding().data(), |
84 | prb->padding_r().data(), init_pd_args.hint, |
85 | dnnl_attr)); |
86 | break; |
87 | case BWD_W: |
88 | case BWD_WB: |
89 | if (prb->dir == BWD_W) bia_d.reset(nullptr); |
90 | DNN_SAFE_STATUS( |
91 | dnnl_convolution_backward_weights_primitive_desc_create( |
92 | &init_pd_args.pd, init_pd_args.engine, alg, src_d, |
93 | wei_d, bia_d, dst_d, prb->strides().data(), |
94 | prb->dilations().data(), prb->padding().data(), |
95 | prb->padding_r().data(), init_pd_args.hint, |
96 | dnnl_attr)); |
97 | break; |
98 | default: DNN_SAFE_STATUS(dnnl_invalid_arguments); |
99 | } |
100 | |
101 | // TODO: add query in od fir accum type. |
102 | //DNN_SAFE_STATUS(cd.accum_data_type == prb->cfg[ACC].dt |
103 | // ? dnnl_success |
104 | // : dnnl_unimplemented); |
105 | return dnnl_success; |
106 | } |
107 | |
108 | std::unique_ptr<prb_t> get_first_conv_prb(const prb_t *prb) { |
109 | const auto &po = prb->attr.post_ops; |
110 | int fusion_index = po.convolution_index(); |
111 | |
112 | attr_t attr; |
113 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { |
114 | auto sc = prb->attr.scales.get(arg); |
115 | if (!sc.is_def()) attr.scales.set(arg, sc); |
116 | } |
117 | |
118 | for (int i = 0; i < fusion_index; ++i) { |
119 | attr.post_ops.entry.push_back(prb->attr.post_ops.entry[i]); |
120 | } |
121 | |
122 | return std::unique_ptr<prb_t>(new prb_t((desc_t)*prb, prb->dir, prb->cfg, |
123 | prb->stag, prb->wtag, tag::any, prb->alg, attr, prb->ctx_init, |
124 | prb->ctx_exe, prb->mb)); |
125 | } |
126 | |
127 | std::unique_ptr<prb_t> get_fused_conv_prb(const prb_t *prb) { |
128 | const auto &po = prb->attr.post_ops; |
129 | int fusion_index = po.convolution_index(); |
130 | if (fusion_index == -1) return nullptr; |
131 | const auto &fused_conv_po = po.entry[fusion_index].convolution; |
132 | |
133 | attr_t fusion_attr; |
134 | // dw_conv src_scale = 1x1_conv dst_scale |
135 | if (!prb->attr.scales.get(DNNL_ARG_DST).is_def()) |
136 | fusion_attr.scales.set( |
137 | DNNL_ARG_SRC, prb->attr.scales.get(DNNL_ARG_DST)); |
138 | if (!fused_conv_po.wei_scale.is_def()) |
139 | fusion_attr.scales.set(DNNL_ARG_WEIGHTS, fused_conv_po.wei_scale); |
140 | if (!fused_conv_po.dst_scale.is_def()) |
141 | fusion_attr.scales.set(DNNL_ARG_DST, fused_conv_po.dst_scale); |
142 | |
143 | for (int i = fusion_index + 1; i < po.len(); ++i) { |
144 | fusion_attr.post_ops.entry.push_back(prb->attr.post_ops.entry[i]); |
145 | } |
146 | |
147 | const auto f32 = dnnl_f32; |
148 | std::stringstream dw_cfg_ss; |
149 | if (prb->cfg[DST].dt == f32 && prb->cfg[WEI].dt == f32 |
150 | && fused_conv_po.dst_dt == f32) |
151 | dw_cfg_ss << prb->cfg[DST].dt; // f32 is a single name |
152 | else // else have all three dt in cfg name |
153 | dw_cfg_ss << prb->cfg[DST].dt << prb->cfg[WEI].dt |
154 | << fused_conv_po.dst_dt; |
155 | auto p_dw_cfg = conv::str2cfg(dw_cfg_ss.str().c_str()); |
156 | |
157 | const auto kernel = fused_conv_po.kernel; |
158 | const auto stride = fused_conv_po.stride; |
159 | const auto padding = fused_conv_po.padding; |
160 | bool is_3d = prb->ndims >= 5; |
161 | bool is_2d = prb->ndims >= 4; |
162 | |
163 | desc_t cd {0}; |
164 | cd.g = prb->oc; |
165 | cd.mb = prb->mb; |
166 | cd.ic = prb->oc; |
167 | cd.id = is_3d ? prb->od : 1; |
168 | cd.ih = is_2d ? prb->oh : 1; |
169 | cd.iw = prb->ow; |
170 | cd.oc = prb->oc; |
171 | cd.kd = is_3d ? kernel : 1; |
172 | cd.kh = is_2d ? kernel : 1; |
173 | cd.kw = kernel; |
174 | cd.sd = is_3d ? stride : 1; |
175 | cd.sh = is_2d ? stride : 1; |
176 | cd.sw = stride; |
177 | cd.pd = is_3d ? padding : 0; |
178 | cd.ph = is_2d ? padding : 0; |
179 | cd.pw = padding; |
180 | // Not following standard convolution formula for output shapes since |
181 | // right/top padding might be greated than left/top one. |
182 | cd.od = is_3d ? div_up(cd.id, stride) : 1; |
183 | cd.oh = is_2d ? div_up(cd.ih, stride) : 1; |
184 | cd.ow = div_up(cd.iw, stride); |
185 | |
186 | cd.has_groups = true; |
187 | cd.ndims = prb->ndims; |
188 | cd.init_pad_r(); |
189 | |
190 | return std::unique_ptr<prb_t>(new prb_t(cd, prb->dir, p_dw_cfg, tag::any, |
191 | tag::any, prb->dtag, alg_t::DIRECT, fusion_attr, prb->ctx_init, |
192 | prb->ctx_exe, prb->mb)); |
193 | } |
194 | |
195 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
196 | skip_unimplemented_data_type( |
197 | {prb->cfg[SRC].dt, prb->cfg[WEI].dt, prb->cfg[DST].dt}, prb->dir, |
198 | res); |
199 | skip_unimplemented_sum_po(prb->attr, res); |
200 | |
201 | // GPU does not support depthwise fusion |
202 | if (is_gpu() && prb->attr.post_ops.convolution_index() != -1) { |
203 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
204 | return; |
205 | } |
206 | } |
207 | |
208 | int doit(const prb_t *prb, res_t *res) { |
209 | if (bench_mode == LIST) return res->state = LISTED, OK; |
210 | |
211 | conv_dw_fusion::skip_unimplemented_prb(prb, res); |
212 | if (res->state == SKIPPED) return OK; |
213 | |
214 | // Original problem with fusion attributes |
215 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim; |
216 | SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN); |
217 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
218 | |
219 | auto const_pd = query_pd(prim); |
220 | |
221 | if (prb->alg == alg_t::AUTO) |
222 | prb->alg = conv::alg_kind2alg(query_alg_kind(const_pd)); |
223 | prb->cfg = auto_cfg(prb->alg, prb->cfg); |
224 | |
225 | const auto &src_md = prb->dir == BWD_D |
226 | ? query_md(const_pd, DNNL_ARG_DIFF_SRC) |
227 | : query_md(const_pd, DNNL_ARG_SRC); |
228 | const auto &wei_md = prb->dir & FLAG_WEI |
229 | ? query_md(const_pd, DNNL_ARG_DIFF_WEIGHTS) |
230 | : query_md(const_pd, DNNL_ARG_WEIGHTS); |
231 | const auto &bia_md = prb->dir & FLAG_WEI |
232 | ? query_md(const_pd, DNNL_ARG_DIFF_BIAS) |
233 | : query_md(const_pd, DNNL_ARG_BIAS); |
234 | const auto &dst_md = prb->dir & FLAG_BWD |
235 | ? query_md(const_pd, DNNL_ARG_DIFF_DST) |
236 | : query_md(const_pd, DNNL_ARG_DST); |
237 | const auto &fused_wei_md = prb->dir & FLAG_WEI |
238 | ? query_md( |
239 | const_pd, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DIFF_WEIGHTS) |
240 | : query_md(const_pd, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); |
241 | const auto &fused_bia_md = prb->dir & FLAG_WEI |
242 | ? query_md(const_pd, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DIFF_BIAS) |
243 | : query_md(const_pd, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS); |
244 | const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD); |
245 | |
246 | const auto &test_engine = get_test_engine(); |
247 | const auto &ref_engine = get_cpu_engine(); |
248 | |
249 | dnn_mem_t src_dt(src_md, test_engine); |
250 | dnn_mem_t wei_dt(wei_md, test_engine); |
251 | dnn_mem_t bia_dt(bia_md, test_engine); |
252 | dnn_mem_t dst_dt(dst_md, test_engine); |
253 | dnn_mem_t fused_wei_dt(fused_wei_md, test_engine); |
254 | dnn_mem_t fused_bia_dt(fused_bia_md, test_engine); |
255 | dnn_mem_t scratchpad_dt(scratchpad_md, test_engine); |
256 | |
257 | const auto fp = dnnl_f32; |
258 | dnn_mem_t src_fp(src_md, fp, tag::abx, ref_engine); |
259 | dnn_mem_t wei_fp(wei_md, fp, tag::abx, ref_engine); |
260 | dnn_mem_t bia_fp(bia_md, fp, tag::x, ref_engine); |
261 | dnn_mem_t dst_fp(dst_md, fp, tag::abx, ref_engine); |
262 | dnn_mem_t fused_wei_fp(fused_wei_md, fp, tag::abx, ref_engine); |
263 | dnn_mem_t fused_bia_fp(fused_bia_md, fp, tag::x, ref_engine); |
264 | |
265 | std::vector<dnn_mem_t> binary_po_dt; |
266 | std::vector<int> binary_po_args; |
267 | |
268 | // Current filling doesn't work for fused_wei due to relying on prb values, |
269 | // which are different for fused conv. This can be fixed later by relying |
270 | // on md values, rather than prb desc ones. |
271 | // Filling for this problem is done below. |
272 | // TODO: fix this if irritates. |
273 | |
274 | // Fill first convolution |
275 | std::unique_ptr<prb_t> p0 = get_first_conv_prb(prb); |
276 | |
277 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim0; |
278 | SAFE(init_prim(prim0, init_pd, p0.get(), res, FLAG_FWD, nullptr, |
279 | /* is_service_prim = */ true), |
280 | WARN); |
281 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
282 | |
283 | auto const_pd0 = query_pd(prim0); |
284 | |
285 | if (p0->alg == alg_t::AUTO) |
286 | p0->alg = conv::alg_kind2alg(query_alg_kind(const_pd0)); |
287 | p0->cfg = auto_cfg(p0->alg, p0->cfg); |
288 | |
289 | const auto &src_md0 = p0->dir == BWD_D |
290 | ? query_md(const_pd0, DNNL_ARG_DIFF_SRC) |
291 | : query_md(const_pd0, DNNL_ARG_SRC); |
292 | const auto &wei_md0 = p0->dir & FLAG_WEI |
293 | ? query_md(const_pd0, DNNL_ARG_DIFF_WEIGHTS) |
294 | : query_md(const_pd0, DNNL_ARG_WEIGHTS); |
295 | const auto &bia_md0 = p0->dir & FLAG_WEI |
296 | ? query_md(const_pd0, DNNL_ARG_DIFF_BIAS) |
297 | : query_md(const_pd0, DNNL_ARG_BIAS); |
298 | const auto &dst_md0 = p0->dir & FLAG_BWD |
299 | ? query_md(const_pd0, DNNL_ARG_DIFF_DST) |
300 | : query_md(const_pd0, DNNL_ARG_DST); |
301 | const auto &scratchpad_md0 = query_md(const_pd0, DNNL_ARG_SCRATCHPAD); |
302 | |
303 | dnn_mem_t src_dt0(src_md0, test_engine); |
304 | dnn_mem_t wei_dt0(wei_md0, test_engine); |
305 | dnn_mem_t bia_dt0(bia_md0, test_engine); |
306 | dnn_mem_t dst_dt0(dst_md0, test_engine); |
307 | dnn_mem_t scratchpad_dt0(scratchpad_md0, test_engine); |
308 | |
309 | dnn_mem_t src_fp0(src_md0, fp, tag::abx, ref_engine); |
310 | dnn_mem_t wei_fp0(wei_md0, fp, tag::abx, ref_engine); |
311 | dnn_mem_t bia_fp0(bia_md0, fp, tag::x, ref_engine); |
312 | dnn_mem_t dst_fp0(dst_md0, fp, tag::abx, ref_engine); |
313 | |
314 | std::vector<dnn_mem_t> binary_po_fp0, binary_po_dt0; |
315 | std::vector<int> binary_po_args0; |
316 | SAFE(binary::setup_binary_po( |
317 | const_pd0, binary_po_args0, binary_po_dt0, binary_po_fp0), |
318 | WARN); |
319 | |
320 | dnn_mem_t src_scales_dt0, src_scales_fp0; |
321 | dnn_mem_t wei_scales_dt0, wei_scales_fp0; |
322 | dnn_mem_t dst_scales_dt0, dst_scales_fp0; |
323 | |
324 | const int src_mask = attr_t::get_default_mask( |
325 | prb->attr.scales.get(DNNL_ARG_SRC).policy); |
326 | const int wei_mask = attr_t::get_default_mask( |
327 | prb->attr.scales.get(DNNL_ARG_WEIGHTS).policy, DNNL_ARG_WEIGHTS); |
328 | const int dst_mask = attr_t::get_default_mask( |
329 | prb->attr.scales.get(DNNL_ARG_DST).policy); |
330 | maybe_prepare_runtime_scales_v2(src_scales_dt0, src_scales_fp0, |
331 | prb->attr.scales.get(DNNL_ARG_SRC), |
332 | prb->desc_nelems(DNNL_ARG_SRC, src_mask), prb->src_scales); |
333 | maybe_prepare_runtime_scales_v2(wei_scales_dt0, wei_scales_fp0, |
334 | prb->attr.scales.get(DNNL_ARG_WEIGHTS), |
335 | prb->desc_nelems(DNNL_ARG_WEIGHTS, wei_mask), prb->wei_scales); |
336 | maybe_prepare_runtime_scales_v2(dst_scales_dt0, dst_scales_fp0, |
337 | prb->attr.scales.get(DNNL_ARG_DST), |
338 | prb->desc_nelems(DNNL_ARG_DST, dst_mask), prb->dst_scales); |
339 | |
340 | SAFE(conv::fill_src(p0.get(), src_dt0, src_fp0, res), WARN); |
341 | SAFE(conv::fill_wei(p0.get(), wei_dt0, wei_fp0, res), WARN); |
342 | SAFE(conv::fill_bia(p0.get(), bia_dt0, bia_fp0, res), WARN); |
343 | SAFE(conv::fill_dst(p0.get(), dst_dt0, dst_fp0, res), WARN); |
344 | |
345 | // Fill next convolution |
346 | std::unique_ptr<prb_t> p1 = get_fused_conv_prb(prb); |
347 | if (!p1) SAFE(FAIL, CRIT); |
348 | |
349 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim1; |
350 | SAFE(init_prim(prim1, init_pd, p1.get(), res, FLAG_FWD, nullptr, |
351 | /* is_service_prim = */ true), |
352 | WARN); |
353 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
354 | |
355 | auto const_pd1 = query_pd(prim1); |
356 | |
357 | if (p1->alg == alg_t::AUTO) |
358 | p1->alg = conv::alg_kind2alg(query_alg_kind(const_pd1)); |
359 | p1->cfg = auto_cfg(p1->alg, p1->cfg); |
360 | |
361 | const auto &src_md1 = prb->dir == BWD_D |
362 | ? query_md(const_pd1, DNNL_ARG_DIFF_SRC) |
363 | : query_md(const_pd1, DNNL_ARG_SRC); |
364 | const auto &wei_md1 = prb->dir & FLAG_WEI |
365 | ? query_md(const_pd1, DNNL_ARG_DIFF_WEIGHTS) |
366 | : query_md(const_pd1, DNNL_ARG_WEIGHTS); |
367 | |
368 | const auto &bia_md1 = prb->dir & FLAG_WEI |
369 | ? query_md(const_pd1, DNNL_ARG_DIFF_BIAS) |
370 | : query_md(const_pd1, DNNL_ARG_BIAS); |
371 | const auto &dst_md1 = prb->dir & FLAG_BWD |
372 | ? query_md(const_pd1, DNNL_ARG_DIFF_DST) |
373 | : query_md(const_pd1, DNNL_ARG_DST); |
374 | const auto &scratchpad_md1 = query_md(const_pd, DNNL_ARG_SCRATCHPAD); |
375 | |
376 | dnn_mem_t src_dt1(src_md1, test_engine); |
377 | dnn_mem_t wei_dt1(wei_md1, test_engine); |
378 | dnn_mem_t bia_dt1(bia_md1, test_engine); |
379 | dnn_mem_t dst_dt1(dst_md1, test_engine); |
380 | dnn_mem_t scratchpad_dt1(scratchpad_md1, test_engine); |
381 | |
382 | dnn_mem_t wei_fp1(wei_md1, fp, tag::abx, ref_engine); |
383 | dnn_mem_t bia_fp1(bia_md1, fp, tag::x, ref_engine); |
384 | dnn_mem_t dst_fp1(dst_md1, fp, tag::abx, ref_engine); |
385 | |
386 | std::vector<dnn_mem_t> binary_po_fp1, binary_po_dt1; |
387 | std::vector<int> binary_po_args1; |
388 | SAFE(binary::setup_binary_po( |
389 | const_pd1, binary_po_args1, binary_po_dt1, binary_po_fp1), |
390 | WARN); |
391 | |
392 | dnn_mem_t wei_scales_dt1, wei_scales_fp1; |
393 | dnn_mem_t dst_scales_dt1, dst_scales_fp1; |
394 | |
395 | int dw_wei_mask = attr_t::get_default_mask( |
396 | p1->attr.scales.get(DNNL_ARG_WEIGHTS).policy, DNNL_ARG_WEIGHTS); |
397 | if (p1->has_groups) dw_wei_mask = (1 << dw_wei_mask) + 1; |
398 | const int dw_dst_mask = attr_t::get_default_mask( |
399 | p1->attr.scales.get(DNNL_ARG_DST).policy); |
400 | maybe_prepare_runtime_scales_v2(wei_scales_dt1, wei_scales_fp1, |
401 | p1->attr.scales.get(DNNL_ARG_WEIGHTS), |
402 | p1->desc_nelems(DNNL_ARG_WEIGHTS, dw_wei_mask), p1->wei_scales); |
403 | maybe_prepare_runtime_scales_v2(dst_scales_dt1, dst_scales_fp1, |
404 | p1->attr.scales.get(DNNL_ARG_DST), |
405 | p1->desc_nelems(DNNL_ARG_DST, dw_dst_mask), p1->dst_scales); |
406 | |
407 | SAFE(conv::fill_wei(p1.get(), wei_dt1, wei_fp1, res), WARN); |
408 | SAFE(conv::fill_bia(p1.get(), bia_dt1, bia_fp1, res), WARN); |
409 | SAFE(conv::fill_dst(p1.get(), dst_dt1, dst_fp1, res), WARN); |
410 | |
411 | // TODO: fix this if irritates. |
412 | // SAFE(conv::fill_src(prb, src_dt, src_fp, res), WARN); |
413 | // SAFE(conv::fill_wei(prb, wei_dt, wei_fp, res), WARN); |
414 | // SAFE(conv::fill_bia(prb, bia_dt, bia_fp, res), WARN); |
415 | // SAFE(conv::fill_dst(prb, dst_dt, dst_fp, res), WARN); |
416 | // SAFE(conv::fill_wei(prb, fused_wei_dt, fused_wei_fp, res), WARN); |
417 | // SAFE(conv::fill_bia(prb, fused_bia_dt, fused_bia_fp, res), WARN); |
418 | // Work around for the issue above |
419 | SAFE(src_dt.reorder(src_fp0), WARN); |
420 | SAFE(wei_dt.reorder(wei_fp0), WARN); |
421 | if (bia_dt.dt() != dnnl_data_type_undef) |
422 | SAFE(bia_dt.reorder(bia_fp0), WARN); |
423 | SAFE(dst_dt.reorder(dst_fp1), WARN); |
424 | SAFE(fused_wei_dt.reorder(wei_fp1), WARN); |
425 | if (fused_bia_dt.dt() != dnnl_data_type_undef) |
426 | SAFE(fused_bia_dt.reorder(bia_fp1), WARN); |
427 | |
428 | args_t args, args0, args1, ref_args; |
429 | |
430 | if (prb->dir & FLAG_FWD) { |
431 | args0.set(DNNL_ARG_SRC, src_dt0); |
432 | args0.set(DNNL_ARG_WEIGHTS, wei_dt0); |
433 | args0.set(DNNL_ARG_BIAS, bia_dt0); |
434 | args0.set(DNNL_ARG_DST, dst_dt0); |
435 | args0.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt0); |
436 | args0.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt0); |
437 | args0.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt0); |
438 | args0.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt0); |
439 | args0.set(binary_po_args0, binary_po_dt0); |
440 | |
441 | SAFE(execute_and_wait(prim0, args0), WARN); |
442 | SAFE(src_dt1.reorder(dst_dt0), WARN); |
443 | |
444 | args1.set(DNNL_ARG_SRC, src_dt1); |
445 | args1.set(DNNL_ARG_WEIGHTS, wei_dt1); |
446 | args1.set(DNNL_ARG_BIAS, bia_dt1); |
447 | args1.set(DNNL_ARG_DST, dst_dt1); |
448 | args1.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, dst_scales_dt0); |
449 | args1.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt1); |
450 | args1.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt1); |
451 | args1.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt1); |
452 | args1.set(binary_po_args1, binary_po_dt1); |
453 | |
454 | SAFE(execute_and_wait(prim1, args1), WARN); |
455 | |
456 | // Reverse engineer binary post-ops indices from second conv and update |
457 | // them in-place to follow fused conv enumaration. |
458 | const int dw_idx = prb->attr.post_ops.convolution_index(); |
459 | const auto update_bin_po_args1_indices = [&](size_t i) { |
460 | auto &b = binary_po_args1[i]; |
461 | const int orig_idx = b / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1; |
462 | b = DNNL_ARG_ATTR_MULTIPLE_POST_OP(orig_idx + dw_idx + 1) |
463 | | DNNL_ARG_SRC_1; |
464 | }; |
465 | for (size_t i = 0; i < binary_po_dt1.size(); ++i) |
466 | update_bin_po_args1_indices(i); |
467 | |
468 | // As memory is not allowed to be copied, and binary post-op memories |
469 | // are read-only, we move them to main convolution execution and adjust |
470 | // arg indices to follow the library API. |
471 | |
472 | // Move the content to binary_po_dt from separate convs. |
473 | std::move(binary_po_dt0.begin(), binary_po_dt0.end(), |
474 | std::back_inserter(binary_po_dt)); |
475 | std::move(binary_po_dt1.begin(), binary_po_dt1.end(), |
476 | std::back_inserter(binary_po_dt)); |
477 | // Move the content to binary_po_args from separate convs. |
478 | std::move(binary_po_args0.begin(), binary_po_args0.end(), |
479 | std::back_inserter(binary_po_args)); |
480 | std::move(binary_po_args1.begin(), binary_po_args1.end(), |
481 | std::back_inserter(binary_po_args)); |
482 | |
483 | args.set(DNNL_ARG_SRC, src_dt); |
484 | args.set(DNNL_ARG_WEIGHTS, wei_dt); |
485 | args.set(DNNL_ARG_BIAS, bia_dt); |
486 | args.set(DNNL_ARG_DST, dst_dt); |
487 | args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt0); |
488 | args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt0); |
489 | args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt0); |
490 | args.set(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS, fused_wei_dt); |
491 | args.set(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS, fused_bia_dt); |
492 | args.set(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_SCALES |
493 | | DNNL_ARG_WEIGHTS, |
494 | wei_scales_dt1); |
495 | args.set(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, |
496 | dst_scales_dt1); |
497 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
498 | args.set(binary_po_args, binary_po_dt); |
499 | |
500 | SAFE(execute_and_wait(prim, args, res), WARN); |
501 | |
502 | if (is_bench_mode(CORR)) { |
503 | compare::compare_t cmp; |
504 | cmp.set_data_kind(DST); |
505 | // Used p1 to avoid writing separate compare function. Compare uses |
506 | // prb->cfg which can be u8s8u8 while after fusion it may be u8s8s8, |
507 | // thus, compare() will saturate values which is not correct. |
508 | conv::setup_cmp(cmp, p1.get(), DST, ref_args); |
509 | |
510 | dnn_mem_t dst_fused(dst_dt, fp, tag::abx, test_engine); |
511 | dnn_mem_t dst_unfused(dst_dt1, fp, tag::abx, test_engine); |
512 | |
513 | cmp.compare(dst_unfused, dst_fused, prb->attr, res); |
514 | } |
515 | } else { |
516 | assert(!"Backward is not supported" ); |
517 | SAFE(FAIL, CRIT); |
518 | } |
519 | |
520 | return measure_perf(prb->ctx_exe, res, prim, args); |
521 | } |
522 | |
523 | } // namespace conv_dw_fusion |
524 | |