1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stddef.h> |
18 | #include <stdio.h> |
19 | #include <stdlib.h> |
20 | |
21 | #include "oneapi/dnnl/dnnl.h" |
22 | |
23 | #include "utils/parallel.hpp" |
24 | |
25 | #include "dnnl_common.hpp" |
26 | #include "dnnl_memory.hpp" |
27 | |
28 | #include "shuffle/shuffle.hpp" |
29 | |
30 | namespace shuffle { |
31 | |
32 | int fill_src(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) { |
33 | auto get_range = [](const dnnl_data_type_t dt) { |
34 | if (dt == dnnl_s8 || dt == dnnl_u8) |
35 | return 256; |
36 | else if (dt == dnnl_bf16 || dt == dnnl_f16) |
37 | return 128; |
38 | return 1024; |
39 | }; |
40 | |
41 | const auto nelems = mem_fp.nelems(); |
42 | const int range = get_range(prb->dt); |
43 | const int f_min = prb->dt == dnnl_u8 ? 0 : -range / 2; |
44 | |
45 | benchdnn_parallel_nd(nelems, [&](int64_t i) { |
46 | const float gen = ((97 * i) + 101) % range; |
47 | const float value = (prb->dt == dnnl_bf16 || prb->dt == dnnl_f16) |
48 | ? (f_min + gen) / range |
49 | : (f_min + gen) * (1.0f + 4.0f / range); |
50 | mem_fp.set_elem(i, round_to_nearest_representable(prb->dt, value)); |
51 | }); |
52 | |
53 | SAFE(mem_dt.reorder(mem_fp), WARN); |
54 | |
55 | return OK; |
56 | } |
57 | |
58 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
59 | const prb_t *prb = init_pd_args.prb; |
60 | |
61 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
62 | create_dnnl_attr(prb->attr, attr_args_t())); |
63 | |
64 | if (prb->dir & FLAG_FWD) { |
65 | auto src_d = dnn_mem_t::init_md( |
66 | prb->ndims, prb->dims.data(), prb->dt, prb->tag); |
67 | auto dst_d = dnn_mem_t::init_md( |
68 | prb->ndims, prb->dims.data(), prb->dt, tag::any); |
69 | |
70 | auto prop_kind = prb->dir & FLAG_INF ? dnnl_forward_inference |
71 | : dnnl_forward_training; |
72 | DNN_SAFE_STATUS(dnnl_shuffle_forward_primitive_desc_create( |
73 | &init_pd_args.pd, init_pd_args.engine, prop_kind, src_d, dst_d, |
74 | prb->axis, prb->group, dnnl_attr)); |
75 | } else { |
76 | auto diff_src_d = dnn_mem_t::init_md( |
77 | prb->ndims, prb->dims.data(), prb->dt, tag::any); |
78 | auto diff_dst_d = dnn_mem_t::init_md( |
79 | prb->ndims, prb->dims.data(), prb->dt, tag::any); |
80 | |
81 | DNN_SAFE_STATUS(dnnl_shuffle_backward_primitive_desc_create( |
82 | &init_pd_args.pd, init_pd_args.engine, diff_src_d, diff_dst_d, |
83 | prb->axis, prb->group, init_pd_args.hint, dnnl_attr)); |
84 | } |
85 | return dnnl_success; |
86 | } |
87 | |
// Runs the common "unimplemented" skip checks for a shuffle problem:
//   - the helper for the problem's single data type and direction;
//   - the helper for sum post-ops in the problem's attributes.
// NOTE(review): both helpers are defined elsewhere. Presumably they record a
// skip reason in `res`; confirm against their definitions.
void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type({prb->dt}, prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res);
}
92 | |
// Hook for rejecting problems that are invalid for the shuffle driver.
// The body is empty: no shuffle-specific validity checks are performed, and
// `res` is not modified here.
void skip_invalid_prb(const prb_t *prb, res_t *res) {}
94 | |
// Comparator customization hook, passed to check_correctness() by doit().
// The body is empty, so `cmp` keeps its default settings for every data
// kind, and `ref_args` is not consulted.
// NOTE(review): shuffle appears to only move values, so defaults expecting
// exact matches seem appropriate. Confirm against compare_t's defaults.
void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {}
97 | |
// Runs one shuffle problem end to end:
//   1. create the primitive;
//   2. allocate test-engine and CPU f32 reference buffers;
//   3. fill the input;
//   4. execute;
//   5. optionally check correctness;
//   6. measure performance.
// Forward reads src and writes dst. Backward reads diff_dst and writes
// diff_src.
int doit(const prb_t *prb, res_t *res) {
    // In LIST mode the problem is only recorded, never executed.
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    // A single descriptor describes both input and output buffers: SRC for
    // forward, DIFF_SRC for backward.
    // NOTE(review): the dst/diff_dst buffers are also created from this
    // descriptor. This assumes the primitive keeps the same layout for the
    // output; confirm for implementations that could pick a different format
    // for the `any` tensor.
    const auto &data_md = prb->dir & FLAG_FWD
            ? query_md(const_pd, DNNL_ARG_SRC)
            : query_md(const_pd, DNNL_ARG_DIFF_SRC);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);
    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    // *_fp: f32 `abx` reference copies on the CPU engine.
    // *_dt: buffers in the primitive's own layout on the test engine.
    dnn_mem_t src_fp(data_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t src_dt(data_md, test_engine);

    dnn_mem_t dst_fp(data_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t dst_dt(data_md, test_engine);

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    args_t args, ref_args;

    if (prb->dir & FLAG_FWD) {
        SAFE(fill_src(prb, src_dt, src_fp), WARN);

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);

            // Forward output is the DST data kind.
            check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
        }
    } else {
        // Backward input is diff_dst; the dst_* buffers hold it, filled with
        // the same generator used for forward src.
        SAFE(fill_src(prb, dst_dt, dst_fp), WARN);

        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);

            // Backward output (diff_src) is compared under the SRC data kind.
            check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res);
        }
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}
158 | |
159 | } // namespace shuffle |
160 | |