/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <stdio.h>
#include <stdlib.h>

#include <random>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "concat/concat.hpp"

namespace concat {

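// Creates a concat primitive descriptor for the given problem: one source
// memory descriptor per input, plus an optional destination descriptor when
// the problem specifies a destination tag.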
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    std::vector<benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t>> src_d_wrappers(
            prb->n_inputs());

    for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) {
        const dims_t &i_vdims = prb->vdims[i_input];
        src_d_wrappers[i_input] = dnn_mem_t::init_md(
                prb->ndims, i_vdims.data(), prb->sdt, prb->stag[i_input]);
    }

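    // When no destination tag is requested, keep dst_d empty so the library
    // deduces the destination memory descriptor on its own.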
    benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> dst_d {};
    if (prb->dtag != tag::undef) {
        dst_d = dnn_mem_t::init_md(
                prb->ndims, prb->dst_dims.data(), prb->ddt, prb->dtag);
    }

    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args_t()));

    std::vector<dnnl_memory_desc_t> src_d(
            src_d_wrappers.begin(), src_d_wrappers.end());
    init_pd_args.is_iterator_supported = false;
    DNN_SAFE_STATUS(dnnl_concat_primitive_desc_create(&init_pd_args.pd,
            init_pd_args.engine, dst_d, prb->n_inputs(), prb->axis,
            src_d.data(), dnnl_attr));

    return dnnl_success;
}

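// Fills one concat input with random integer values. The value range is
// restricted so that the data survives reorders between the tested and
// reference data types without change (see the range selection below).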
int fill_src(int input_idx, dnnl_data_type_t dt, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp) {
    const auto nelems = mem_fp.nelems();
    // Do fixed partitioning to have the same filling regardless of the
    // number of threads.
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);
    // Set a range of valid values so that reorders back and forth cannot
    // alter them.
    const bool s8u8_or_u8s8 = (dt == dnnl_s8 && mem_dt.dt() == dnnl_u8)
            || (dt == dnnl_u8 && mem_dt.dt() == dnnl_s8);
    float min_val = lowest_dt(dnnl_s8);
    float max_val = max_dt(dnnl_u8);
    if (s8u8_or_u8s8) {
        min_val = lowest_dt(dnnl_u8);
        max_val = max_dt(dnnl_s8);
    } else if (dt == dnnl_s8 || mem_dt.dt() == dnnl_s8) {
        max_val = max_dt(dnnl_s8);
    } else if (dt == dnnl_u8 || mem_dt.dt() == dnnl_u8) {
        min_val = lowest_dt(dnnl_u8);
    }

    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // See eltwise.cpp for implementation details.
        std::minstd_rand msr(input_idx * n_chunks + idx_start + 1);
        msr.discard(1);
        std::uniform_int_distribution<> igen(min_val, max_val);
        // No need to round the final value as it is already representable
        // in the needed data type.
        for (int64_t idx = idx_start; idx < idx_end; ++idx)
            mem_fp.set_elem(idx, (float)igen(msr));
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

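// Skips problems no implementation is expected to support: unsupported data
// types, sum post-ops, unsupported per-argument scales, and the CPU xf16
// reorder restrictions described below.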
void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type({prb->sdt, prb->ddt}, prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res);
    skip_unimplemented_arg_scale(prb->attr, res);

    // Reference concat is reorder-based, hence it inherits some reorder
    // limitations: bf16 and f16 reorders on CPU support only
    // [bf16, f16]<->f32 conversions.
    bool valid_xf16_input
            = IMPLICATION(prb->sdt == dnnl_bf16 || prb->sdt == dnnl_f16,
                    prb->dtag == tag::undef || prb->ddt == dnnl_f32
                            || prb->ddt == prb->sdt);
    bool valid_xf16_output
            = IMPLICATION((prb->ddt == dnnl_bf16 || prb->ddt == dnnl_f16)
                            && prb->dtag != tag::undef,
                    (prb->sdt == dnnl_f32 || prb->sdt == prb->ddt));

    if (is_cpu() && (!valid_xf16_input || !valid_xf16_output)) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }
}

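// Concat has no additional invalid problem configurations to reject here.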
void skip_invalid_prb(const prb_t *prb, res_t *res) {}

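// The default comparison settings are sufficient for concat; nothing to
// adjust here.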
void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {}

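// Main test flow: create the primitive, allocate and fill memories, execute,
// optionally check correctness against the f32 reference path, and measure
// performance.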
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();
    const auto &dst_md = query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    args_t args, ref_args;

    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

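    // For every input create two memories: an f32 copy in a plain layout for
    // the reference path and a copy in the library-queried layout for the
    // tested primitive.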
    std::vector<dnn_mem_t> src_fp, src_dt, scales;
    src_fp.reserve(prb->n_inputs());
    src_dt.reserve(prb->n_inputs());
    scales.resize(prb->n_inputs());

    for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) {
        const auto &src_md
                = query_md(const_pd, DNNL_ARG_MULTIPLE_SRC + i_input);
        src_fp.emplace_back(src_md, dnnl_f32, tag::abx, ref_engine);
        src_dt.emplace_back(src_md, test_engine);
        SAFE(fill_src(i_input, dst_dt.dt(), src_dt[i_input], src_fp[i_input]),
                WARN);
        args.set(DNNL_ARG_MULTIPLE_SRC + i_input, src_dt[i_input]);
        if (is_bench_mode(CORR))
            ref_args.set(DNNL_ARG_MULTIPLE_SRC + i_input, src_fp[i_input]);

        // Prepare per-input scales defined by the attributes.
        const auto &sc = prb->attr.scales.get(DNNL_ARG_MULTIPLE_SRC + i_input);
        float scale_val = sc.scale;
        maybe_prepare_runtime_scales(scales[i_input], sc, 1, &scale_val);
        args.set((DNNL_ARG_MULTIPLE_SRC + i_input) | DNNL_ARG_ATTR_SCALES,
                scales[i_input]);
    }

    SAFE(execute_and_wait(prim, args, res), WARN);

    if (is_bench_mode(CORR)) {
        ref_args.set(DNNL_ARG_DST, dst_fp);

        check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
    }

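    // Performance measurement; a no-op unless perf mode is enabled.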
    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace concat