1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <iterator>
18
19#include <float.h>
20#include <math.h>
21#include <stdio.h>
22#include <stdlib.h>
23
24#include "oneapi/dnnl/dnnl.h"
25
26#include "dnnl_common.hpp"
27#include "dnnl_memory.hpp"
28
29#include "binary/binary.hpp"
30#include "conv/conv_dw_fusion.hpp"
31
32namespace conv_dw_fusion {
33
// Creates the primitive descriptor for the fused problem `prb` (convolution
// with a depthwise convolution post-op attached via attributes).
// Returns dnnl_success on success, or propagates the library status through
// DNN_SAFE_STATUS otherwise.
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    // Memory descriptors for the first convolution's tensors. Weights get an
    // extra leading dimension when groups are present; bias is rank-1 and its
    // layout is left to the implementation (tag::any).
    auto src_d = dnn_mem_t::init_md(
            prb->ndims, prb->src_dims().data(), prb->cfg[SRC].dt, prb->stag);
    auto wei_d = dnn_mem_t::init_md(prb->ndims + prb->has_groups,
            prb->wei_dims().data(), prb->cfg[WEI].dt, prb->wtag);
    auto bia_d = dnn_mem_t::init_md(
            1, prb->bia_dims().data(), prb->cfg[BIA].dt, tag::any);
    auto dst_d = dnn_mem_t::init_md(
            prb->ndims, prb->dst_dims().data(), prb->cfg[DST].dt, prb->dtag);

    // Map benchdnn's algorithm enum to the library's kind; DIRECT is default.
    dnnl_alg_kind_t alg = dnnl_convolution_direct;
    if (prb->alg == alg_t::WINO) alg = dnnl_convolution_winograd;
    if (prb->alg == alg_t::AUTO) alg = dnnl_convolution_auto;

    attr_args_t attr_args;

    auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS);
    if (wei_scale.policy == policy_t::PER_OC) {
        // Mask value differs depending on whether the weights carry a groups
        // dimension (per-OC scales then span two dims instead of one).
        auto wei_mask = prb->has_groups ? 2 : 1;
        attr_args.prepare_scales(prb->attr, DNNL_ARG_WEIGHTS, prb->wei_scales,
                prb->oc, wei_mask);
    }
    // The fused depthwise conv gets an f32 bias only when the direction
    // requests a bias (FWD_B); otherwise the bias type is undefined.
    const auto dw_bia_dt = prb->dir == FWD_B ? dnnl_f32 : dnnl_data_type_undef;
    attr_args.prepare_dw_post_op(prb->attr, prb->cfg[WEI].dt, dw_bia_dt);
    attr_args.prepare_post_ops_mds(
            prb->attr, prb->ndims, prb->dst_dims().data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    switch (prb->dir) {
        case FWD_D:
        case FWD_B:
        case FWD_I:
            // Bias descriptor is only passed for FWD_B; drop it otherwise.
            if (prb->dir != FWD_B) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(dnnl_convolution_forward_primitive_desc_create(
                    &init_pd_args.pd, init_pd_args.engine,
                    prb->dir == FWD_I ? dnnl_forward_inference
                                      : dnnl_forward_training,
                    alg, src_d, wei_d, bia_d, dst_d, prb->strides().data(),
                    prb->dilations().data(), prb->padding().data(),
                    prb->padding_r().data(), dnnl_attr));
            break;
        case BWD_D:
            DNN_SAFE_STATUS(
                    dnnl_convolution_backward_data_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, alg, src_d,
                            wei_d, dst_d, prb->strides().data(),
                            prb->dilations().data(), prb->padding().data(),
                            prb->padding_r().data(), init_pd_args.hint,
                            dnnl_attr));
            break;
        case BWD_W:
        case BWD_WB:
            // Diff bias is only requested for BWD_WB.
            if (prb->dir == BWD_W) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(
                    dnnl_convolution_backward_weights_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, alg, src_d,
                            wei_d, bia_d, dst_d, prb->strides().data(),
                            prb->dilations().data(), prb->padding().data(),
                            prb->padding_r().data(), init_pd_args.hint,
                            dnnl_attr));
            break;
        default: DNN_SAFE_STATUS(dnnl_invalid_arguments);
    }

    // TODO: add a query for the accumulation data type and validate it
    // against prb->cfg[ACC].dt, as the commented check below intended.
    //DNN_SAFE_STATUS(cd.accum_data_type == prb->cfg[ACC].dt
    //                ? dnnl_success
    //                : dnnl_unimplemented);
    return dnnl_success;
}
107
108std::unique_ptr<prb_t> get_first_conv_prb(const prb_t *prb) {
109 const auto &po = prb->attr.post_ops;
110 int fusion_index = po.convolution_index();
111
112 attr_t attr;
113 for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
114 auto sc = prb->attr.scales.get(arg);
115 if (!sc.is_def()) attr.scales.set(arg, sc);
116 }
117
118 for (int i = 0; i < fusion_index; ++i) {
119 attr.post_ops.entry.push_back(prb->attr.post_ops.entry[i]);
120 }
121
122 return std::unique_ptr<prb_t>(new prb_t((desc_t)*prb, prb->dir, prb->cfg,
123 prb->stag, prb->wtag, tag::any, prb->alg, attr, prb->ctx_init,
124 prb->ctx_exe, prb->mb));
125}
126
// Builds a standalone problem for the fused depthwise convolution out of the
// dw post-op entry of `prb`. Returns nullptr if `prb` has no dw post-op.
std::unique_ptr<prb_t> get_fused_conv_prb(const prb_t *prb) {
    const auto &po = prb->attr.post_ops;
    int fusion_index = po.convolution_index();
    if (fusion_index == -1) return nullptr;
    const auto &fused_conv_po = po.entry[fusion_index].convolution;

    attr_t fusion_attr;
    // dw_conv src_scale = 1x1_conv dst_scale
    if (!prb->attr.scales.get(DNNL_ARG_DST).is_def())
        fusion_attr.scales.set(
                DNNL_ARG_SRC, prb->attr.scales.get(DNNL_ARG_DST));
    if (!fused_conv_po.wei_scale.is_def())
        fusion_attr.scales.set(DNNL_ARG_WEIGHTS, fused_conv_po.wei_scale);
    if (!fused_conv_po.dst_scale.is_def())
        fusion_attr.scales.set(DNNL_ARG_DST, fused_conv_po.dst_scale);

    // Post-ops that come after the fusion point belong to the dw conv.
    for (int i = fusion_index + 1; i < po.len(); ++i) {
        fusion_attr.post_ops.entry.push_back(prb->attr.post_ops.entry[i]);
    }

    // Construct the dw conv config name from <src><wei><dst> data types;
    // an all-f32 combination collapses to the single name "f32".
    const auto f32 = dnnl_f32;
    std::stringstream dw_cfg_ss;
    if (prb->cfg[DST].dt == f32 && prb->cfg[WEI].dt == f32
            && fused_conv_po.dst_dt == f32)
        dw_cfg_ss << prb->cfg[DST].dt; // f32 is a single name
    else // else have all three dt in cfg name
        dw_cfg_ss << prb->cfg[DST].dt << prb->cfg[WEI].dt
                  << fused_conv_po.dst_dt;
    auto p_dw_cfg = conv::str2cfg(dw_cfg_ss.str().c_str());

    const auto kernel = fused_conv_po.kernel;
    const auto stride = fused_conv_po.stride;
    const auto padding = fused_conv_po.padding;
    bool is_3d = prb->ndims >= 5;
    bool is_2d = prb->ndims >= 4;

    // Depthwise geometry: groups == ic == oc == first conv's oc, and the dw
    // input spatial sizes are the first conv's output spatial sizes. Unused
    // spatial dims (for 1d/2d cases) degenerate to size 1 / padding 0.
    desc_t cd {0};
    cd.g = prb->oc;
    cd.mb = prb->mb;
    cd.ic = prb->oc;
    cd.id = is_3d ? prb->od : 1;
    cd.ih = is_2d ? prb->oh : 1;
    cd.iw = prb->ow;
    cd.oc = prb->oc;
    cd.kd = is_3d ? kernel : 1;
    cd.kh = is_2d ? kernel : 1;
    cd.kw = kernel;
    cd.sd = is_3d ? stride : 1;
    cd.sh = is_2d ? stride : 1;
    cd.sw = stride;
    cd.pd = is_3d ? padding : 0;
    cd.ph = is_2d ? padding : 0;
    cd.pw = padding;
    // Not following standard convolution formula for output shapes since
    // right/bottom padding might be greater than left/top one.
    cd.od = is_3d ? div_up(cd.id, stride) : 1;
    cd.oh = is_2d ? div_up(cd.ih, stride) : 1;
    cd.ow = div_up(cd.iw, stride);

    cd.has_groups = true;
    cd.ndims = prb->ndims;
    cd.init_pad_r();

    return std::unique_ptr<prb_t>(new prb_t(cd, prb->dir, p_dw_cfg, tag::any,
            tag::any, prb->dtag, alg_t::DIRECT, fusion_attr, prb->ctx_init,
            prb->ctx_exe, prb->mb));
}
194
195void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
196 skip_unimplemented_data_type(
197 {prb->cfg[SRC].dt, prb->cfg[WEI].dt, prb->cfg[DST].dt}, prb->dir,
198 res);
199 skip_unimplemented_sum_po(prb->attr, res);
200
201 // GPU does not support depthwise fusion
202 if (is_gpu() && prb->attr.post_ops.convolution_index() != -1) {
203 res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
204 return;
205 }
206}
207
208int doit(const prb_t *prb, res_t *res) {
209 if (bench_mode == LIST) return res->state = LISTED, OK;
210
211 conv_dw_fusion::skip_unimplemented_prb(prb, res);
212 if (res->state == SKIPPED) return OK;
213
214 // Original problem with fusion attributes
215 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
216 SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
217 if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
218
219 auto const_pd = query_pd(prim);
220
221 if (prb->alg == alg_t::AUTO)
222 prb->alg = conv::alg_kind2alg(query_alg_kind(const_pd));
223 prb->cfg = auto_cfg(prb->alg, prb->cfg);
224
225 const auto &src_md = prb->dir == BWD_D
226 ? query_md(const_pd, DNNL_ARG_DIFF_SRC)
227 : query_md(const_pd, DNNL_ARG_SRC);
228 const auto &wei_md = prb->dir & FLAG_WEI
229 ? query_md(const_pd, DNNL_ARG_DIFF_WEIGHTS)
230 : query_md(const_pd, DNNL_ARG_WEIGHTS);
231 const auto &bia_md = prb->dir & FLAG_WEI
232 ? query_md(const_pd, DNNL_ARG_DIFF_BIAS)
233 : query_md(const_pd, DNNL_ARG_BIAS);
234 const auto &dst_md = prb->dir & FLAG_BWD
235 ? query_md(const_pd, DNNL_ARG_DIFF_DST)
236 : query_md(const_pd, DNNL_ARG_DST);
237 const auto &fused_wei_md = prb->dir & FLAG_WEI
238 ? query_md(
239 const_pd, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DIFF_WEIGHTS)
240 : query_md(const_pd, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
241 const auto &fused_bia_md = prb->dir & FLAG_WEI
242 ? query_md(const_pd, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DIFF_BIAS)
243 : query_md(const_pd, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
244 const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);
245
246 const auto &test_engine = get_test_engine();
247 const auto &ref_engine = get_cpu_engine();
248
249 dnn_mem_t src_dt(src_md, test_engine);
250 dnn_mem_t wei_dt(wei_md, test_engine);
251 dnn_mem_t bia_dt(bia_md, test_engine);
252 dnn_mem_t dst_dt(dst_md, test_engine);
253 dnn_mem_t fused_wei_dt(fused_wei_md, test_engine);
254 dnn_mem_t fused_bia_dt(fused_bia_md, test_engine);
255 dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
256
257 const auto fp = dnnl_f32;
258 dnn_mem_t src_fp(src_md, fp, tag::abx, ref_engine);
259 dnn_mem_t wei_fp(wei_md, fp, tag::abx, ref_engine);
260 dnn_mem_t bia_fp(bia_md, fp, tag::x, ref_engine);
261 dnn_mem_t dst_fp(dst_md, fp, tag::abx, ref_engine);
262 dnn_mem_t fused_wei_fp(fused_wei_md, fp, tag::abx, ref_engine);
263 dnn_mem_t fused_bia_fp(fused_bia_md, fp, tag::x, ref_engine);
264
265 std::vector<dnn_mem_t> binary_po_dt;
266 std::vector<int> binary_po_args;
267
268 // Current filling doesn't work for fused_wei due to relying on prb values,
269 // which are different for fused conv. This can be fixed later by relying
270 // on md values, rather than prb desc ones.
271 // Filling for this problem is done below.
272 // TODO: fix this if irritates.
273
274 // Fill first convolution
275 std::unique_ptr<prb_t> p0 = get_first_conv_prb(prb);
276
277 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim0;
278 SAFE(init_prim(prim0, init_pd, p0.get(), res, FLAG_FWD, nullptr,
279 /* is_service_prim = */ true),
280 WARN);
281 if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
282
283 auto const_pd0 = query_pd(prim0);
284
285 if (p0->alg == alg_t::AUTO)
286 p0->alg = conv::alg_kind2alg(query_alg_kind(const_pd0));
287 p0->cfg = auto_cfg(p0->alg, p0->cfg);
288
289 const auto &src_md0 = p0->dir == BWD_D
290 ? query_md(const_pd0, DNNL_ARG_DIFF_SRC)
291 : query_md(const_pd0, DNNL_ARG_SRC);
292 const auto &wei_md0 = p0->dir & FLAG_WEI
293 ? query_md(const_pd0, DNNL_ARG_DIFF_WEIGHTS)
294 : query_md(const_pd0, DNNL_ARG_WEIGHTS);
295 const auto &bia_md0 = p0->dir & FLAG_WEI
296 ? query_md(const_pd0, DNNL_ARG_DIFF_BIAS)
297 : query_md(const_pd0, DNNL_ARG_BIAS);
298 const auto &dst_md0 = p0->dir & FLAG_BWD
299 ? query_md(const_pd0, DNNL_ARG_DIFF_DST)
300 : query_md(const_pd0, DNNL_ARG_DST);
301 const auto &scratchpad_md0 = query_md(const_pd0, DNNL_ARG_SCRATCHPAD);
302
303 dnn_mem_t src_dt0(src_md0, test_engine);
304 dnn_mem_t wei_dt0(wei_md0, test_engine);
305 dnn_mem_t bia_dt0(bia_md0, test_engine);
306 dnn_mem_t dst_dt0(dst_md0, test_engine);
307 dnn_mem_t scratchpad_dt0(scratchpad_md0, test_engine);
308
309 dnn_mem_t src_fp0(src_md0, fp, tag::abx, ref_engine);
310 dnn_mem_t wei_fp0(wei_md0, fp, tag::abx, ref_engine);
311 dnn_mem_t bia_fp0(bia_md0, fp, tag::x, ref_engine);
312 dnn_mem_t dst_fp0(dst_md0, fp, tag::abx, ref_engine);
313
314 std::vector<dnn_mem_t> binary_po_fp0, binary_po_dt0;
315 std::vector<int> binary_po_args0;
316 SAFE(binary::setup_binary_po(
317 const_pd0, binary_po_args0, binary_po_dt0, binary_po_fp0),
318 WARN);
319
320 dnn_mem_t src_scales_dt0, src_scales_fp0;
321 dnn_mem_t wei_scales_dt0, wei_scales_fp0;
322 dnn_mem_t dst_scales_dt0, dst_scales_fp0;
323
324 const int src_mask = attr_t::get_default_mask(
325 prb->attr.scales.get(DNNL_ARG_SRC).policy);
326 const int wei_mask = attr_t::get_default_mask(
327 prb->attr.scales.get(DNNL_ARG_WEIGHTS).policy, DNNL_ARG_WEIGHTS);
328 const int dst_mask = attr_t::get_default_mask(
329 prb->attr.scales.get(DNNL_ARG_DST).policy);
330 maybe_prepare_runtime_scales_v2(src_scales_dt0, src_scales_fp0,
331 prb->attr.scales.get(DNNL_ARG_SRC),
332 prb->desc_nelems(DNNL_ARG_SRC, src_mask), prb->src_scales);
333 maybe_prepare_runtime_scales_v2(wei_scales_dt0, wei_scales_fp0,
334 prb->attr.scales.get(DNNL_ARG_WEIGHTS),
335 prb->desc_nelems(DNNL_ARG_WEIGHTS, wei_mask), prb->wei_scales);
336 maybe_prepare_runtime_scales_v2(dst_scales_dt0, dst_scales_fp0,
337 prb->attr.scales.get(DNNL_ARG_DST),
338 prb->desc_nelems(DNNL_ARG_DST, dst_mask), prb->dst_scales);
339
340 SAFE(conv::fill_src(p0.get(), src_dt0, src_fp0, res), WARN);
341 SAFE(conv::fill_wei(p0.get(), wei_dt0, wei_fp0, res), WARN);
342 SAFE(conv::fill_bia(p0.get(), bia_dt0, bia_fp0, res), WARN);
343 SAFE(conv::fill_dst(p0.get(), dst_dt0, dst_fp0, res), WARN);
344
345 // Fill next convolution
346 std::unique_ptr<prb_t> p1 = get_fused_conv_prb(prb);
347 if (!p1) SAFE(FAIL, CRIT);
348
349 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim1;
350 SAFE(init_prim(prim1, init_pd, p1.get(), res, FLAG_FWD, nullptr,
351 /* is_service_prim = */ true),
352 WARN);
353 if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
354
355 auto const_pd1 = query_pd(prim1);
356
357 if (p1->alg == alg_t::AUTO)
358 p1->alg = conv::alg_kind2alg(query_alg_kind(const_pd1));
359 p1->cfg = auto_cfg(p1->alg, p1->cfg);
360
361 const auto &src_md1 = prb->dir == BWD_D
362 ? query_md(const_pd1, DNNL_ARG_DIFF_SRC)
363 : query_md(const_pd1, DNNL_ARG_SRC);
364 const auto &wei_md1 = prb->dir & FLAG_WEI
365 ? query_md(const_pd1, DNNL_ARG_DIFF_WEIGHTS)
366 : query_md(const_pd1, DNNL_ARG_WEIGHTS);
367
368 const auto &bia_md1 = prb->dir & FLAG_WEI
369 ? query_md(const_pd1, DNNL_ARG_DIFF_BIAS)
370 : query_md(const_pd1, DNNL_ARG_BIAS);
371 const auto &dst_md1 = prb->dir & FLAG_BWD
372 ? query_md(const_pd1, DNNL_ARG_DIFF_DST)
373 : query_md(const_pd1, DNNL_ARG_DST);
374 const auto &scratchpad_md1 = query_md(const_pd, DNNL_ARG_SCRATCHPAD);
375
376 dnn_mem_t src_dt1(src_md1, test_engine);
377 dnn_mem_t wei_dt1(wei_md1, test_engine);
378 dnn_mem_t bia_dt1(bia_md1, test_engine);
379 dnn_mem_t dst_dt1(dst_md1, test_engine);
380 dnn_mem_t scratchpad_dt1(scratchpad_md1, test_engine);
381
382 dnn_mem_t wei_fp1(wei_md1, fp, tag::abx, ref_engine);
383 dnn_mem_t bia_fp1(bia_md1, fp, tag::x, ref_engine);
384 dnn_mem_t dst_fp1(dst_md1, fp, tag::abx, ref_engine);
385
386 std::vector<dnn_mem_t> binary_po_fp1, binary_po_dt1;
387 std::vector<int> binary_po_args1;
388 SAFE(binary::setup_binary_po(
389 const_pd1, binary_po_args1, binary_po_dt1, binary_po_fp1),
390 WARN);
391
392 dnn_mem_t wei_scales_dt1, wei_scales_fp1;
393 dnn_mem_t dst_scales_dt1, dst_scales_fp1;
394
395 int dw_wei_mask = attr_t::get_default_mask(
396 p1->attr.scales.get(DNNL_ARG_WEIGHTS).policy, DNNL_ARG_WEIGHTS);
397 if (p1->has_groups) dw_wei_mask = (1 << dw_wei_mask) + 1;
398 const int dw_dst_mask = attr_t::get_default_mask(
399 p1->attr.scales.get(DNNL_ARG_DST).policy);
400 maybe_prepare_runtime_scales_v2(wei_scales_dt1, wei_scales_fp1,
401 p1->attr.scales.get(DNNL_ARG_WEIGHTS),
402 p1->desc_nelems(DNNL_ARG_WEIGHTS, dw_wei_mask), p1->wei_scales);
403 maybe_prepare_runtime_scales_v2(dst_scales_dt1, dst_scales_fp1,
404 p1->attr.scales.get(DNNL_ARG_DST),
405 p1->desc_nelems(DNNL_ARG_DST, dw_dst_mask), p1->dst_scales);
406
407 SAFE(conv::fill_wei(p1.get(), wei_dt1, wei_fp1, res), WARN);
408 SAFE(conv::fill_bia(p1.get(), bia_dt1, bia_fp1, res), WARN);
409 SAFE(conv::fill_dst(p1.get(), dst_dt1, dst_fp1, res), WARN);
410
411 // TODO: fix this if irritates.
412 // SAFE(conv::fill_src(prb, src_dt, src_fp, res), WARN);
413 // SAFE(conv::fill_wei(prb, wei_dt, wei_fp, res), WARN);
414 // SAFE(conv::fill_bia(prb, bia_dt, bia_fp, res), WARN);
415 // SAFE(conv::fill_dst(prb, dst_dt, dst_fp, res), WARN);
416 // SAFE(conv::fill_wei(prb, fused_wei_dt, fused_wei_fp, res), WARN);
417 // SAFE(conv::fill_bia(prb, fused_bia_dt, fused_bia_fp, res), WARN);
418 // Work around for the issue above
419 SAFE(src_dt.reorder(src_fp0), WARN);
420 SAFE(wei_dt.reorder(wei_fp0), WARN);
421 if (bia_dt.dt() != dnnl_data_type_undef)
422 SAFE(bia_dt.reorder(bia_fp0), WARN);
423 SAFE(dst_dt.reorder(dst_fp1), WARN);
424 SAFE(fused_wei_dt.reorder(wei_fp1), WARN);
425 if (fused_bia_dt.dt() != dnnl_data_type_undef)
426 SAFE(fused_bia_dt.reorder(bia_fp1), WARN);
427
428 args_t args, args0, args1, ref_args;
429
430 if (prb->dir & FLAG_FWD) {
431 args0.set(DNNL_ARG_SRC, src_dt0);
432 args0.set(DNNL_ARG_WEIGHTS, wei_dt0);
433 args0.set(DNNL_ARG_BIAS, bia_dt0);
434 args0.set(DNNL_ARG_DST, dst_dt0);
435 args0.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt0);
436 args0.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt0);
437 args0.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt0);
438 args0.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt0);
439 args0.set(binary_po_args0, binary_po_dt0);
440
441 SAFE(execute_and_wait(prim0, args0), WARN);
442 SAFE(src_dt1.reorder(dst_dt0), WARN);
443
444 args1.set(DNNL_ARG_SRC, src_dt1);
445 args1.set(DNNL_ARG_WEIGHTS, wei_dt1);
446 args1.set(DNNL_ARG_BIAS, bia_dt1);
447 args1.set(DNNL_ARG_DST, dst_dt1);
448 args1.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, dst_scales_dt0);
449 args1.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt1);
450 args1.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt1);
451 args1.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt1);
452 args1.set(binary_po_args1, binary_po_dt1);
453
454 SAFE(execute_and_wait(prim1, args1), WARN);
455
456 // Reverse engineer binary post-ops indices from second conv and update
457 // them in-place to follow fused conv enumaration.
458 const int dw_idx = prb->attr.post_ops.convolution_index();
459 const auto update_bin_po_args1_indices = [&](size_t i) {
460 auto &b = binary_po_args1[i];
461 const int orig_idx = b / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
462 b = DNNL_ARG_ATTR_MULTIPLE_POST_OP(orig_idx + dw_idx + 1)
463 | DNNL_ARG_SRC_1;
464 };
465 for (size_t i = 0; i < binary_po_dt1.size(); ++i)
466 update_bin_po_args1_indices(i);
467
468 // As memory is not allowed to be copied, and binary post-op memories
469 // are read-only, we move them to main convolution execution and adjust
470 // arg indices to follow the library API.
471
472 // Move the content to binary_po_dt from separate convs.
473 std::move(binary_po_dt0.begin(), binary_po_dt0.end(),
474 std::back_inserter(binary_po_dt));
475 std::move(binary_po_dt1.begin(), binary_po_dt1.end(),
476 std::back_inserter(binary_po_dt));
477 // Move the content to binary_po_args from separate convs.
478 std::move(binary_po_args0.begin(), binary_po_args0.end(),
479 std::back_inserter(binary_po_args));
480 std::move(binary_po_args1.begin(), binary_po_args1.end(),
481 std::back_inserter(binary_po_args));
482
483 args.set(DNNL_ARG_SRC, src_dt);
484 args.set(DNNL_ARG_WEIGHTS, wei_dt);
485 args.set(DNNL_ARG_BIAS, bia_dt);
486 args.set(DNNL_ARG_DST, dst_dt);
487 args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt0);
488 args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt0);
489 args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt0);
490 args.set(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS, fused_wei_dt);
491 args.set(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS, fused_bia_dt);
492 args.set(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_SCALES
493 | DNNL_ARG_WEIGHTS,
494 wei_scales_dt1);
495 args.set(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST,
496 dst_scales_dt1);
497 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
498 args.set(binary_po_args, binary_po_dt);
499
500 SAFE(execute_and_wait(prim, args, res), WARN);
501
502 if (is_bench_mode(CORR)) {
503 compare::compare_t cmp;
504 cmp.set_data_kind(DST);
505 // Used p1 to avoid writing separate compare function. Compare uses
506 // prb->cfg which can be u8s8u8 while after fusion it may be u8s8s8,
507 // thus, compare() will saturate values which is not correct.
508 conv::setup_cmp(cmp, p1.get(), DST, ref_args);
509
510 dnn_mem_t dst_fused(dst_dt, fp, tag::abx, test_engine);
511 dnn_mem_t dst_unfused(dst_dt1, fp, tag::abx, test_engine);
512
513 cmp.compare(dst_unfused, dst_fused, prb->attr, res);
514 }
515 } else {
516 assert(!"Backward is not supported");
517 SAFE(FAIL, CRIT);
518 }
519
520 return measure_perf(prb->ctx_exe, res, prim, args);
521}
522
523} // namespace conv_dw_fusion
524