1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef REDUCTION_HPP |
18 | #define REDUCTION_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | #include "dnn_types.hpp" |
23 | #include "dnnl_common.hpp" |
24 | #include "utils/perf_report.hpp" |
25 | #include "utils/settings.hpp" |
26 | |
27 | namespace reduction { |
28 | |
// Reduction algorithm kinds used by this benchmark driver.
//
// Values are written out explicitly; they are the same 0..9 sequence the
// enumerators would receive implicitly from their declaration order.
enum alg_t {
    undef = 0,
    min = 1,
    max = 2,
    mul = 3,
    sum = 4,
    mean = 5,
    norm_lp_max = 6,
    norm_lp_sum = 7,
    norm_lp_power_p_max = 8,
    norm_lp_power_p_sum = 9,

    // Long-form aliases: each shares the value of the matching short name.
    // Presumably these mirror the library's dnnl_reduction_* spelling so
    // either form can be used -- TODO confirm against str2alg().
    reduction_min = min,
    reduction_max = max,
    reduction_mul = mul,
    reduction_sum = sum,
    reduction_mean = mean,
    reduction_norm_lp_max = norm_lp_max,
    reduction_norm_lp_sum = norm_lp_sum,
    reduction_norm_lp_power_p_max = norm_lp_power_p_max,
    reduction_norm_lp_power_p_sum = norm_lp_power_p_sum,
};
50 | |
// Conversions for alg_t; the definitions are not in this header.

// Presumably parses an algorithm name into alg_t -- TODO confirm how unknown
// names are reported.
alg_t str2alg(const char *str);
// Returns a printable name for `alg`; the perf report streams this string
// when dumping the algorithm column.
const char *alg2str(alg_t alg);
// Presumably maps `alg` to the corresponding oneDNN dnnl_alg_kind_t -- TODO
// confirm the behavior for alg_t::undef.
dnnl_alg_kind_t alg2alg_kind(alg_t alg);
54 | |
55 | struct settings_t : public base_settings_t { |
56 | settings_t() = default; |
57 | |
58 | // ctor to save certain fields from resetting |
59 | settings_t(const char *perf_template) : settings_t() { |
60 | this->perf_template = perf_template; |
61 | } |
62 | |
63 | prb_vdims_t prb_vdims; |
64 | |
65 | std::vector<dnnl_data_type_t> sdt {dnnl_f32}; |
66 | std::vector<dnnl_data_type_t> ddt {dnnl_f32}; |
67 | std::vector<std::string> stag {tag::abx}; |
68 | std::vector<std::string> dtag {tag::any}; |
69 | std::vector<alg_t> alg {alg_t::sum}; |
70 | std::vector<float> p {1.0f}, eps {0.0f}; |
71 | |
72 | const char *perf_template_csv() const { |
73 | static const std::string args = "%sdt%,%ddt%,%stag%,%dtag%,%alg%" ; |
74 | return perf_template_csv_base(args); |
75 | } |
76 | |
77 | void reset() { *this = settings_t(perf_template); } |
78 | }; |
79 | |
80 | struct prb_t : public prb_vdims_t { |
81 | prb_t(const prb_vdims_t &prb_vdims, dnnl_data_type_t sdt, |
82 | dnnl_data_type_t ddt, const std::string &stag, |
83 | const std::string &dtag, alg_t alg, float p, float eps, |
84 | const attr_t &attr, const thr_ctx_t &ctx_init, |
85 | const thr_ctx_t &ctx_exe) |
86 | : prb_vdims_t(prb_vdims) |
87 | , sdt(sdt) |
88 | , ddt(ddt) |
89 | , stag(stag) |
90 | , dtag(dtag) |
91 | , alg(alg) |
92 | , p(p) |
93 | , eps(eps) |
94 | , attr(attr) |
95 | , ctx_init(ctx_init) |
96 | , ctx_exe(ctx_exe) {} |
97 | |
98 | dir_t dir = FLAG_FWD; // Lack of prop_kind, always considered as forward. |
99 | dnnl_data_type_t sdt, ddt; |
100 | std::string stag, dtag; |
101 | alg_t alg; |
102 | float p, eps; |
103 | attr_t attr; |
104 | thr_ctx_t ctx_init, ctx_exe; |
105 | }; |
106 | |
// Streams a textual description of `prb` to `s` and returns `s`; the
// definition is not in this header (presumably it prints the problem in the
// driver's command-line form -- TODO confirm).
std::ostream &operator<<(std::ostream &s, const prb_t &prb);
108 | |
109 | struct perf_report_t : public base_perf_report_t { |
110 | perf_report_t(const prb_t *prb, const char *perf_template) |
111 | : base_perf_report_t(perf_template) |
112 | , prb_(prb) |
113 | , sdt_({prb_->sdt}) |
114 | , stag_({normalize_tag(prb_->stag, prb_->ndims)}) |
115 | , dtag_(normalize_tag(prb_->dtag, prb_->ndims)) {} |
116 | |
117 | void dump_alg(std::ostream &s) const override { s << alg2str(prb_->alg); } |
118 | |
119 | void dump_desc(std::ostream &s) const override { |
120 | s << static_cast<const prb_vdims_t &>(*prb_); |
121 | } |
122 | |
123 | void dump_desc_csv(std::ostream &s) const override { dump_desc(s); } |
124 | |
125 | const attr_t *attr() const override { return &prb_->attr; } |
126 | const thr_ctx_t *ctx_init() const override { return &prb_->ctx_init; } |
127 | const thr_ctx_t *ctx_exe() const override { return &prb_->ctx_exe; } |
128 | const std::string *name() const override { return &prb_->name; } |
129 | const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; } |
130 | const dnnl_data_type_t *ddt() const override { return &prb_->ddt; } |
131 | const std::vector<std::string> *stag() const override { return &stag_; } |
132 | const std::string *dtag() const override { return &dtag_; } |
133 | |
134 | private: |
135 | const prb_t *prb_; |
136 | std::vector<dnnl_data_type_t> sdt_; |
137 | std::vector<std::string> stag_; |
138 | std::string dtag_; |
139 | }; |
140 | |
// Driver entry points; the definitions are not in this header. The roles
// below are inferred from the names -- TODO confirm against the .cpp.

// Presumably marks `res` as skipped when the problem is not implemented.
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
// Presumably marks `res` as skipped when the problem is invalid.
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Reference computation for `prb` over `args`. `prim_ref` is optional and
// defaults to nullptr.
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);

// Presumably runs one problem and records the outcome in `res`; returns an
// int status code.
int doit(const prb_t *prb, res_t *res);
// Driver entry point; receives an argc/argv-style argument list.
int bench(int argc, char **argv);
148 | |
149 | } // namespace reduction |
150 | |
151 | #endif |
152 | |