1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef REORDER_HPP |
18 | #define REORDER_HPP |
19 | |
20 | #include <iostream> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "common.hpp" |
25 | #include "dnn_types.hpp" |
26 | #include "dnnl_common.hpp" |
27 | #include "utils/perf_report.hpp" |
28 | #include "utils/settings.hpp" |
29 | |
30 | namespace reorder { |
31 | |
// Flag bits used with the driver's `oflag` setting. Judging by the prb_t
// helpers that take a flag_bit_t, they select compensation kinds
// (s8s8 and zero-point).
enum flag_bit_t {
    FLAG_NONE = 0x0U,
    FLAG_S8S8_COMP = 0x1U << 0,
    FLAG_ZP_COMP = 0x1U << 1,
    // All bits set. Reserved for internal use only.
    FLAG_ANY = ~FLAG_NONE,
};
// A flag paired with an integer argument. Presumably the int is the
// compensation mask; confirm against the str2flag definition.
using flag_t = std::pair<flag_bit_t, int>;
// Parses a textual flag specification into a flag_t.
flag_t str2flag(const char *str);
// Returns the textual name of a flag bit.
std::string flag2str(flag_bit_t flag);
41 | |
42 | struct dt_conf_s { |
43 | dnnl_data_type_t dt; |
44 | float min; |
45 | float max; |
46 | }; |
47 | typedef const dt_conf_s *dt_conf_t; |
48 | dt_conf_t dt2cfg(dnnl_data_type_t dt); |
49 | dnnl_data_type_t cfg2dt(dt_conf_t cfg); |
50 | |
// Engine pairing for the reorder: same engine (NONE), or a transfer between
// CPU and GPU engines in the named direction.
enum cross_engine_t {
    NONE = 0,
    CPU2GPU = 1,
    GPU2CPU = 2,
};
// Parses a textual cross-engine mode.
cross_engine_t str2cross_engine(const char *str);
// Returns the textual name of a cross-engine mode.
const char *cross_engine2str(cross_engine_t cross_engine);
54 | |
55 | struct settings_t : public base_settings_t { |
56 | settings_t() = default; |
57 | |
58 | // ctor to save certain fields from resetting |
59 | settings_t(const char *perf_template) : settings_t() { |
60 | this->perf_template = perf_template; |
61 | } |
62 | |
63 | prb_dims_t prb_dims; |
64 | |
65 | std::vector<dnnl_data_type_t> sdt {dnnl_f32}, ddt {dnnl_f32}; |
66 | std::vector<std::string> stag {tag::abx}, dtag {tag::abx}; |
67 | std::vector<std::vector<flag_t>> oflag {{{FLAG_NONE, 0}}}; |
68 | std::vector<unsigned> runtime_dim_mask {0}; |
69 | std::vector<cross_engine_t> cross_engine {NONE}; |
70 | |
71 | // Just to increase the coverage, doesn't participate in prb construction. |
72 | std::vector<float> def_scale {0.125, 0.25, 0.5, 1, 2, 4, 8}; |
73 | |
74 | const char *perf_template_csv() const { |
75 | static const std::string args = "%sdt%,%ddt%,%stag%,%dtag%,%flags%" ; |
76 | return perf_template_csv_base(args); |
77 | } |
78 | |
79 | void reset() { *this = settings_t(perf_template); } |
80 | }; |
81 | |
82 | struct prb_t : public prb_dims_t { |
83 | prb_t(const prb_dims_t &prb_dims, dnnl_data_type_t sdt, |
84 | dnnl_data_type_t ddt, const std::string &stag, |
85 | const std::string &dtag, const attr_t &attr, |
86 | const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe, |
87 | const std::vector<flag_t> &oflag, cross_engine_t cross_engine, |
88 | unsigned runtime_dim_mask) |
89 | : prb_dims_t(prb_dims) |
90 | , sdt(sdt) |
91 | , ddt(ddt) |
92 | , stag(stag) |
93 | , dtag(dtag) |
94 | , attr(attr) |
95 | , ctx_init(ctx_init) |
96 | , ctx_exe(ctx_exe) |
97 | , oflag(oflag) |
98 | , cross_engine(cross_engine) |
99 | , runtime_dim_mask(runtime_dim_mask) { |
100 | src_zp = generate_zero_points(DNNL_ARG_SRC); |
101 | dst_zp = generate_zero_points(DNNL_ARG_DST); |
102 | src_scales = generate_scales(DNNL_ARG_SRC); |
103 | dst_scales = generate_scales(DNNL_ARG_DST); |
104 | } |
105 | ~prb_t() { |
106 | if (src_zp) zfree(src_zp); |
107 | if (dst_zp) zfree(dst_zp); |
108 | if (src_scales) zfree(src_scales); |
109 | if (dst_scales) zfree(dst_scales); |
110 | } |
111 | |
112 | dir_t dir = FLAG_FWD; // Lack of prop_kind, always considered as forward. |
113 | dnnl_data_type_t sdt, ddt; |
114 | std::string stag, dtag; |
115 | attr_t attr; |
116 | thr_ctx_t ctx_init, ctx_exe; |
117 | std::vector<flag_t> oflag; |
118 | cross_engine_t cross_engine; |
119 | unsigned runtime_dim_mask; |
120 | int32_t *src_zp, *dst_zp; |
121 | float *src_scales, *dst_scales; |
122 | |
123 | bool is_reorder_with_compensation(flag_bit_t flag) const; |
124 | dims_t get_compensation_dims(flag_bit_t flag) const; |
125 | int get_compensation_mask(flag_bit_t flag) const; |
126 | int32_t *generate_zero_points(int arg) const; |
127 | float *generate_scales(int arg) const; |
128 | dt_conf_t get_conf(data_kind_t kind) const; |
129 | |
130 | private: |
131 | void get_compensation_parameters( |
132 | dims_t &comp_dims, int &mask, flag_bit_t flag) const; |
133 | }; |
134 | std::ostream &operator<<(std::ostream &s, const prb_t &prb); |
135 | std::ostream &operator<<(std::ostream &s, const std::vector<flag_t> &oflag); |
136 | |
137 | struct perf_report_t : public base_perf_report_t { |
138 | perf_report_t(const prb_t *prb, const char *perf_template) |
139 | : base_perf_report_t(perf_template) |
140 | , p_(prb) |
141 | , sdt_({p_->sdt}) |
142 | , stag_({normalize_tag(p_->stag, p_->ndims)}) |
143 | , dtag_(normalize_tag(p_->dtag, p_->ndims)) {} |
144 | |
145 | void dump_desc(std::ostream &s) const override { |
146 | s << static_cast<const prb_dims_t &>(*p_); |
147 | } |
148 | |
149 | void dump_desc_csv(std::ostream &s) const override { dump_desc(s); } |
150 | |
151 | void dump_engine(std::ostream &s) const override { |
152 | if (p_->cross_engine == CPU2GPU) |
153 | s << "cpu2gpu" ; |
154 | else if (p_->cross_engine == GPU2CPU) |
155 | s << "gpu2cpu" ; |
156 | else |
157 | base_perf_report_t::dump_engine(s); |
158 | } |
159 | |
160 | void dump_flags(std::ostream &s) const override { s << p_->oflag; } |
161 | |
162 | const attr_t *attr() const override { return &p_->attr; } |
163 | const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; } |
164 | const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; } |
165 | const std::string *name() const override { return &p_->name; } |
166 | const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; } |
167 | const dnnl_data_type_t *ddt() const override { return &p_->ddt; } |
168 | const std::vector<std::string> *stag() const override { return &stag_; } |
169 | const std::string *dtag() const override { return &dtag_; } |
170 | |
171 | private: |
172 | const prb_t *p_; |
173 | std::vector<dnnl_data_type_t> sdt_; |
174 | std::vector<std::string> stag_; |
175 | std::string dtag_; |
176 | }; |
177 | |
178 | void skip_unimplemented_prb(const prb_t *prb, res_t *res); |
179 | void skip_invalid_prb(const prb_t *prb, res_t *res); |
180 | void compute_ref(const prb_t *prb, const args_t &args, |
181 | dnnl_primitive_t prim_ref = nullptr); |
182 | |
183 | int doit(const prb_t *prb, res_t *res); |
184 | int bench(int argc, char **argv); |
185 | |
186 | } // namespace reorder |
187 | |
188 | #endif |
189 | |