1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef REORDER_HPP
18#define REORDER_HPP
19
20#include <iostream>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "common.hpp"
25#include "dnn_types.hpp"
26#include "dnnl_common.hpp"
27#include "utils/perf_report.hpp"
28#include "utils/settings.hpp"
29
30namespace reorder {
31
// Output-memory flags a reorder problem may request. Each non-trivial value
// is a distinct bit. Flags are consumed by the compensation queries of prb_t
// (is_reorder_with_compensation / get_compensation_dims /
// get_compensation_mask), which all take a flag_bit_t.
enum flag_bit_t {
    FLAG_NONE = 0x0U,
    // NOTE(review): by name, s8s8 compensation — confirm in reorder.cpp.
    FLAG_S8S8_COMP = 0x1U,
    // NOTE(review): by name, zero-point compensation — confirm in reorder.cpp.
    FLAG_ZP_COMP = 0x2U,
    // All bits set (~FLAG_NONE).
    FLAG_ANY = ~FLAG_NONE, // For internal use only.
};
// A flag paired with an integer. The settings default is {FLAG_NONE, 0}.
// NOTE(review): the int is presumably a mask parsed together with the flag —
// confirm in str2flag().
using flag_t = std::pair<flag_bit_t, int>;
// Parses a flag specification from a C string (definition not visible here).
flag_t str2flag(const char *str);
// Converts a flag to its textual form (presumably the name str2flag accepts).
std::string flag2str(flag_bit_t flag);
41
42struct dt_conf_s {
43 dnnl_data_type_t dt;
44 float min;
45 float max;
46};
47typedef const dt_conf_s *dt_conf_t;
48dt_conf_t dt2cfg(dnnl_data_type_t dt);
49dnnl_data_type_t cfg2dt(dt_conf_t cfg);
50
// Cross-engine mode of a reorder. perf_report_t prints CPU2GPU as "cpu2gpu"
// and GPU2CPU as "gpu2cpu"; for NONE it falls back to the base engine dump.
// NOTE(review): by name, CPU2GPU presumably means source on CPU and
// destination on GPU (and vice versa) — confirm in reorder.cpp.
enum cross_engine_t { NONE, CPU2GPU, GPU2CPU };
// Parses a cross-engine mode from a C string (definition not visible here).
cross_engine_t str2cross_engine(const char *str);
// Returns the textual form of a cross-engine mode.
const char *cross_engine2str(cross_engine_t cross_engine);
54
55struct settings_t : public base_settings_t {
56 settings_t() = default;
57
58 // ctor to save certain fields from resetting
59 settings_t(const char *perf_template) : settings_t() {
60 this->perf_template = perf_template;
61 }
62
63 prb_dims_t prb_dims;
64
65 std::vector<dnnl_data_type_t> sdt {dnnl_f32}, ddt {dnnl_f32};
66 std::vector<std::string> stag {tag::abx}, dtag {tag::abx};
67 std::vector<std::vector<flag_t>> oflag {{{FLAG_NONE, 0}}};
68 std::vector<unsigned> runtime_dim_mask {0};
69 std::vector<cross_engine_t> cross_engine {NONE};
70
71 // Just to increase the coverage, doesn't participate in prb construction.
72 std::vector<float> def_scale {0.125, 0.25, 0.5, 1, 2, 4, 8};
73
74 const char *perf_template_csv() const {
75 static const std::string args = "%sdt%,%ddt%,%stag%,%dtag%,%flags%";
76 return perf_template_csv_base(args);
77 }
78
79 void reset() { *this = settings_t(perf_template); }
80};
81
82struct prb_t : public prb_dims_t {
83 prb_t(const prb_dims_t &prb_dims, dnnl_data_type_t sdt,
84 dnnl_data_type_t ddt, const std::string &stag,
85 const std::string &dtag, const attr_t &attr,
86 const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe,
87 const std::vector<flag_t> &oflag, cross_engine_t cross_engine,
88 unsigned runtime_dim_mask)
89 : prb_dims_t(prb_dims)
90 , sdt(sdt)
91 , ddt(ddt)
92 , stag(stag)
93 , dtag(dtag)
94 , attr(attr)
95 , ctx_init(ctx_init)
96 , ctx_exe(ctx_exe)
97 , oflag(oflag)
98 , cross_engine(cross_engine)
99 , runtime_dim_mask(runtime_dim_mask) {
100 src_zp = generate_zero_points(DNNL_ARG_SRC);
101 dst_zp = generate_zero_points(DNNL_ARG_DST);
102 src_scales = generate_scales(DNNL_ARG_SRC);
103 dst_scales = generate_scales(DNNL_ARG_DST);
104 }
105 ~prb_t() {
106 if (src_zp) zfree(src_zp);
107 if (dst_zp) zfree(dst_zp);
108 if (src_scales) zfree(src_scales);
109 if (dst_scales) zfree(dst_scales);
110 }
111
112 dir_t dir = FLAG_FWD; // Lack of prop_kind, always considered as forward.
113 dnnl_data_type_t sdt, ddt;
114 std::string stag, dtag;
115 attr_t attr;
116 thr_ctx_t ctx_init, ctx_exe;
117 std::vector<flag_t> oflag;
118 cross_engine_t cross_engine;
119 unsigned runtime_dim_mask;
120 int32_t *src_zp, *dst_zp;
121 float *src_scales, *dst_scales;
122
123 bool is_reorder_with_compensation(flag_bit_t flag) const;
124 dims_t get_compensation_dims(flag_bit_t flag) const;
125 int get_compensation_mask(flag_bit_t flag) const;
126 int32_t *generate_zero_points(int arg) const;
127 float *generate_scales(int arg) const;
128 dt_conf_t get_conf(data_kind_t kind) const;
129
130private:
131 void get_compensation_parameters(
132 dims_t &comp_dims, int &mask, flag_bit_t flag) const;
133};
// Writes a textual description of a reorder problem to `s`.
std::ostream &operator<<(std::ostream &s, const prb_t &prb);
// Writes a list of output flags to `s` (used by perf_report_t::dump_flags).
std::ostream &operator<<(std::ostream &s, const std::vector<flag_t> &oflag);
136
137struct perf_report_t : public base_perf_report_t {
138 perf_report_t(const prb_t *prb, const char *perf_template)
139 : base_perf_report_t(perf_template)
140 , p_(prb)
141 , sdt_({p_->sdt})
142 , stag_({normalize_tag(p_->stag, p_->ndims)})
143 , dtag_(normalize_tag(p_->dtag, p_->ndims)) {}
144
145 void dump_desc(std::ostream &s) const override {
146 s << static_cast<const prb_dims_t &>(*p_);
147 }
148
149 void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }
150
151 void dump_engine(std::ostream &s) const override {
152 if (p_->cross_engine == CPU2GPU)
153 s << "cpu2gpu";
154 else if (p_->cross_engine == GPU2CPU)
155 s << "gpu2cpu";
156 else
157 base_perf_report_t::dump_engine(s);
158 }
159
160 void dump_flags(std::ostream &s) const override { s << p_->oflag; }
161
162 const attr_t *attr() const override { return &p_->attr; }
163 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
164 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
165 const std::string *name() const override { return &p_->name; }
166 const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; }
167 const dnnl_data_type_t *ddt() const override { return &p_->ddt; }
168 const std::vector<std::string> *stag() const override { return &stag_; }
169 const std::string *dtag() const override { return &dtag_; }
170
171private:
172 const prb_t *p_;
173 std::vector<dnnl_data_type_t> sdt_;
174 std::vector<std::string> stag_;
175 std::string dtag_;
176};
177
// Record in `res` whether `prb` should be skipped (presumably as
// unimplemented / invalid respectively — definitions not visible here).
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Computes the reference result for `prb` over `args`. `prim_ref` is an
// optional primitive handle and defaults to nullptr.
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);

// Runs a single reorder problem and records the outcome in `res`
// (int status code; presumably OK/FAIL conventions of benchdnn).
int doit(const prb_t *prb, res_t *res);
// Driver entry point taking command-line style arguments.
int bench(int argc, char **argv);
185
186} // namespace reorder
187
188#endif
189