1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <stddef.h>
18#include <stdio.h>
19#include <stdlib.h>
20
21#include "oneapi/dnnl/dnnl.h"
22
23#include "utils/parallel.hpp"
24
25#include "dnnl_common.hpp"
26#include "dnnl_memory.hpp"
27
28#include "shuffle/shuffle.hpp"
29
30namespace shuffle {
31
32int fill_src(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
33 auto get_range = [](const dnnl_data_type_t dt) {
34 if (dt == dnnl_s8 || dt == dnnl_u8)
35 return 256;
36 else if (dt == dnnl_bf16 || dt == dnnl_f16)
37 return 128;
38 return 1024;
39 };
40
41 const auto nelems = mem_fp.nelems();
42 const int range = get_range(prb->dt);
43 const int f_min = prb->dt == dnnl_u8 ? 0 : -range / 2;
44
45 benchdnn_parallel_nd(nelems, [&](int64_t i) {
46 const float gen = ((97 * i) + 101) % range;
47 const float value = (prb->dt == dnnl_bf16 || prb->dt == dnnl_f16)
48 ? (f_min + gen) / range
49 : (f_min + gen) * (1.0f + 4.0f / range);
50 mem_fp.set_elem(i, round_to_nearest_representable(prb->dt, value));
51 });
52
53 SAFE(mem_dt.reorder(mem_fp), WARN);
54
55 return OK;
56}
57
58dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
59 const prb_t *prb = init_pd_args.prb;
60
61 auto dnnl_attr = make_benchdnn_dnnl_wrapper(
62 create_dnnl_attr(prb->attr, attr_args_t()));
63
64 if (prb->dir & FLAG_FWD) {
65 auto src_d = dnn_mem_t::init_md(
66 prb->ndims, prb->dims.data(), prb->dt, prb->tag);
67 auto dst_d = dnn_mem_t::init_md(
68 prb->ndims, prb->dims.data(), prb->dt, tag::any);
69
70 auto prop_kind = prb->dir & FLAG_INF ? dnnl_forward_inference
71 : dnnl_forward_training;
72 DNN_SAFE_STATUS(dnnl_shuffle_forward_primitive_desc_create(
73 &init_pd_args.pd, init_pd_args.engine, prop_kind, src_d, dst_d,
74 prb->axis, prb->group, dnnl_attr));
75 } else {
76 auto diff_src_d = dnn_mem_t::init_md(
77 prb->ndims, prb->dims.data(), prb->dt, tag::any);
78 auto diff_dst_d = dnn_mem_t::init_md(
79 prb->ndims, prb->dims.data(), prb->dt, tag::any);
80
81 DNN_SAFE_STATUS(dnnl_shuffle_backward_primitive_desc_create(
82 &init_pd_args.pd, init_pd_args.engine, diff_src_d, diff_dst_d,
83 prb->axis, prb->group, init_pd_args.hint, dnnl_attr));
84 }
85 return dnnl_success;
86}
87
88void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
89 skip_unimplemented_data_type({prb->dt}, prb->dir, res);
90 skip_unimplemented_sum_po(prb->attr, res);
91}
92
93void skip_invalid_prb(const prb_t *prb, res_t *res) {}
94
95void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
96 const args_t &ref_args) {}
97
98int doit(const prb_t *prb, res_t *res) {
99 if (bench_mode == LIST) return res->state = LISTED, OK;
100
101 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
102 SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
103 if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
104
105 auto const_pd = query_pd(prim);
106
107 const auto &data_md = prb->dir & FLAG_FWD
108 ? query_md(const_pd, DNNL_ARG_SRC)
109 : query_md(const_pd, DNNL_ARG_DIFF_SRC);
110 const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);
111 const auto &test_engine = get_test_engine();
112 const auto &ref_engine = get_cpu_engine();
113
114 dnn_mem_t src_fp(data_md, dnnl_f32, tag::abx, ref_engine);
115 dnn_mem_t src_dt(data_md, test_engine);
116
117 dnn_mem_t dst_fp(data_md, dnnl_f32, tag::abx, ref_engine);
118 dnn_mem_t dst_dt(data_md, test_engine);
119
120 dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
121
122 args_t args, ref_args;
123
124 if (prb->dir & FLAG_FWD) {
125 SAFE(fill_src(prb, src_dt, src_fp), WARN);
126
127 args.set(DNNL_ARG_SRC, src_dt);
128 args.set(DNNL_ARG_DST, dst_dt);
129 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
130
131 SAFE(execute_and_wait(prim, args, res), WARN);
132
133 if (is_bench_mode(CORR)) {
134 ref_args.set(DNNL_ARG_SRC, src_fp);
135 ref_args.set(DNNL_ARG_DST, dst_fp);
136
137 check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
138 }
139 } else {
140 SAFE(fill_src(prb, dst_dt, dst_fp), WARN);
141
142 args.set(DNNL_ARG_DIFF_DST, dst_dt);
143 args.set(DNNL_ARG_DIFF_SRC, src_dt);
144 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
145
146 SAFE(execute_and_wait(prim, args, res), WARN);
147
148 if (is_bench_mode(CORR)) {
149 ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
150 ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);
151
152 check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res);
153 }
154 }
155
156 return measure_perf(prb->ctx_exe, res, prim, args);
157}
158
159} // namespace shuffle
160