1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stdio.h> |
18 | #include <stdlib.h> |
19 | |
20 | #include <random> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "utils/parallel.hpp" |
25 | |
26 | #include "dnnl_common.hpp" |
27 | #include "dnnl_memory.hpp" |
28 | |
29 | #include "concat/concat.hpp" |
30 | |
31 | namespace concat { |
32 | |
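// Creates a concat primitive descriptor from the test problem: one memory
// descriptor per input, an optional destination descriptor, and the
// problem's attributes.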
33 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
34 | const prb_t *prb = init_pd_args.prb; |
35 | |
36 | std::vector<benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t>> src_d_wrappers( |
37 | prb->n_inputs()); |
38 | |
39 | for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) { |
40 | const dims_t &i_vdims = prb->vdims[i_input]; |
41 | src_d_wrappers[i_input] = dnn_mem_t::init_md( |
42 | prb->ndims, i_vdims.data(), prb->sdt, prb->stag[i_input]); |
43 | } |
44 | |
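    // The destination descriptor is optional: when the tag is undefined, the
    // library deduces the destination memory format on its own.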
45 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> dst_d {}; |
46 | if (prb->dtag != tag::undef) { |
47 | dst_d = dnn_mem_t::init_md( |
48 | prb->ndims, prb->dst_dims.data(), prb->ddt, prb->dtag); |
49 | } |
50 | |
51 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
52 | create_dnnl_attr(prb->attr, attr_args_t())); |
53 | |
54 | std::vector<dnnl_memory_desc_t> src_d( |
55 | src_d_wrappers.begin(), src_d_wrappers.end()); |
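    // Concat does not support primitive descriptor iteration over
    // implementations.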
56 | init_pd_args.is_iterator_supported = false; |
57 | DNN_SAFE_STATUS(dnnl_concat_primitive_desc_create(&init_pd_args.pd, |
58 | init_pd_args.engine, dst_d, prb->n_inputs(), prb->axis, |
59 | src_d.data(), dnnl_attr)); |
60 | |
61 | return dnnl_success; |
62 | } |
63 | |
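// Fills one concat input with random integer values chosen so that the
// conversion between the f32 reference buffer and the library data type is
// lossless.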
64 | int fill_src(int input_idx, dnnl_data_type_t dt, dnn_mem_t &mem_dt, |
65 | dnn_mem_t &mem_fp) { |
66 | const auto nelems = mem_fp.nelems(); |
    // Use fixed partitioning to get the same filling regardless of the
    // number of threads.
68 | const int64_t n_chunks = 16; |
69 | const int64_t chunk_size = div_up(nelems, n_chunks); |
    // Restrict values to a range representable in both the source and
    // destination data types so that reorders back and forth do not change
    // them. When one side is s8 and the other is u8, only the intersection
    // [0, 127] is safe.
71 | const bool s8u8_or_u8s8 = (dt == dnnl_s8 && mem_dt.dt() == dnnl_u8) |
72 | || (dt == dnnl_u8 && mem_dt.dt() == dnnl_s8); |
73 | float min_val = lowest_dt(dnnl_s8); |
74 | float max_val = max_dt(dnnl_u8); |
75 | if (s8u8_or_u8s8) { |
76 | min_val = lowest_dt(dnnl_u8); |
77 | max_val = max_dt(dnnl_s8); |
78 | } else if (dt == dnnl_s8 || mem_dt.dt() == dnnl_s8) { |
79 | max_val = max_dt(dnnl_s8); |
80 | } else if (dt == dnnl_u8 || mem_dt.dt() == dnnl_u8) { |
81 | min_val = lowest_dt(dnnl_u8); |
82 | } |
83 | |
84 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) { |
85 | int64_t idx_start = idx_chunk * chunk_size; |
86 | int64_t idx_end = MIN2(idx_start + chunk_size, nelems); |
87 | // See eltwise.cpp for implementation details. |
88 | std::minstd_rand msr(input_idx * n_chunks + idx_start + 1); |
89 | msr.discard(1); |
90 | std::uniform_int_distribution<> igen(min_val, max_val); |
        // No need to round the final value as it is already representable
        // in the target data type.
92 | for (int64_t idx = idx_start; idx < idx_end; ++idx) |
93 | mem_fp.set_elem(idx, (float)igen(msr)); |
94 | }); |
95 | |
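    // Copy the f32 reference data into the library memory, converting it to
    // the target data type and layout.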
96 | SAFE(mem_dt.reorder(mem_fp), WARN); |
97 | |
98 | return OK; |
99 | } |
100 | |
101 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
102 | skip_unimplemented_data_type({prb->sdt, prb->ddt}, prb->dir, res); |
103 | skip_unimplemented_sum_po(prb->attr, res); |
104 | skip_unimplemented_arg_scale(prb->attr, res); |
105 | |
    // Reference concat is reorder-based and hence inherits some reorder
    // limitations: bf16 and f16 reorders on CPU support only
    // [bf16, f16]<->f32 conversions.
108 | bool valid_xf16_input |
109 | = IMPLICATION(prb->sdt == dnnl_bf16 || prb->sdt == dnnl_f16, |
110 | prb->dtag == tag::undef || prb->ddt == dnnl_f32 |
111 | || prb->ddt == prb->sdt); |
112 | bool valid_xf16_output |
113 | = IMPLICATION((prb->ddt == dnnl_bf16 || prb->ddt == dnnl_f16) |
114 | && prb->dtag != tag::undef, |
115 | (prb->sdt == dnnl_f32 || prb->sdt == prb->ddt)); |
116 | |
117 | if (is_cpu() && (!valid_xf16_input || !valid_xf16_output)) { |
118 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
119 | return; |
120 | } |
121 | } |
122 | |
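// Concat needs no extra invalid-problem checks and relies on the default
// comparator settings, hence the empty implementations below.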
123 | void skip_invalid_prb(const prb_t *prb, res_t *res) {} |
124 | |
125 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
126 | const args_t &ref_args) {} |
127 | |
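// Main test flow: create the primitive, allocate and fill memories, execute,
// then optionally validate correctness and measure performance.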
128 | int doit(const prb_t *prb, res_t *res) { |
129 | if (bench_mode == LIST) return res->state = LISTED, OK; |
130 | |
131 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim; |
132 | SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN); |
133 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
134 | |
135 | auto const_pd = query_pd(prim); |
136 | |
137 | const auto &test_engine = get_test_engine(); |
138 | const auto &ref_engine = get_cpu_engine(); |
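    // Query memory descriptors from the primitive descriptor so the test
    // uses exactly the formats the implementation expects.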
139 | const auto &dst_md = query_md(const_pd, DNNL_ARG_DST); |
140 | const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD); |
141 | |
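    // dst_fp is a plain f32 reference buffer on the CPU engine; dst_dt is
    // the library memory in the implementation-chosen format.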
142 | dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, ref_engine); |
143 | dnn_mem_t dst_dt(dst_md, test_engine); |
144 | dnn_mem_t scratchpad_dt(scratchpad_md, test_engine); |
145 | |
146 | args_t args, ref_args; |
147 | |
148 | args.set(DNNL_ARG_DST, dst_dt); |
149 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
150 | |
151 | std::vector<dnn_mem_t> src_fp, src_dt, scales; |
152 | src_fp.reserve(prb->n_inputs()); |
153 | src_dt.reserve(prb->n_inputs()); |
154 | scales.resize(prb->n_inputs()); |
155 | |
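    // Create, fill, and register each source memory along with its runtime
    // scale.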
156 | for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) { |
157 | const auto &src_md |
158 | = query_md(const_pd, DNNL_ARG_MULTIPLE_SRC + i_input); |
159 | src_fp.emplace_back(src_md, dnnl_f32, tag::abx, ref_engine); |
160 | src_dt.emplace_back(src_md, test_engine); |
161 | SAFE(fill_src(i_input, dst_dt.dt(), src_dt[i_input], src_fp[i_input]), |
162 | WARN); |
163 | args.set(DNNL_ARG_MULTIPLE_SRC + i_input, src_dt[i_input]); |
164 | if (is_bench_mode(CORR)) |
165 | ref_args.set(DNNL_ARG_MULTIPLE_SRC + i_input, src_fp[i_input]); |
166 | |
        // Prepare runtime scales for this input from the problem attributes.
168 | const auto &sc = prb->attr.scales.get(DNNL_ARG_MULTIPLE_SRC + i_input); |
169 | float scale_val = sc.scale; |
170 | maybe_prepare_runtime_scales(scales[i_input], sc, 1, &scale_val); |
171 | args.set((DNNL_ARG_MULTIPLE_SRC + i_input) | DNNL_ARG_ATTR_SCALES, |
172 | scales[i_input]); |
173 | } |
174 | |
175 | SAFE(execute_and_wait(prim, args, res), WARN); |
176 | |
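    // In correctness mode, run the reference computation on the f32 buffers
    // and compare against the library result.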
177 | if (is_bench_mode(CORR)) { |
178 | ref_args.set(DNNL_ARG_DST, dst_fp); |
179 | |
180 | check_correctness(prb, {DST}, args, ref_args, setup_cmp, res); |
181 | } |
182 | |
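    // measure_perf() collects timings only when a performance mode is
    // requested; otherwise it returns right away.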
183 | return measure_perf(prb->ctx_exe, res, prim, args); |
184 | } |
185 | |
186 | } // namespace concat |
187 | |