1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef REDUCTION_HPP
18#define REDUCTION_HPP
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22#include "dnn_types.hpp"
23#include "dnnl_common.hpp"
24#include "utils/perf_report.hpp"
25#include "utils/settings.hpp"
26
27namespace reduction {
28
// Reduction algorithms known to the driver.
// The primary enumerators are numbered consecutively starting from `undef`.
// Every `reduction_`-prefixed name further down is only an alias: it carries
// exactly the same value as its unprefixed counterpart.
enum alg_t {
    undef = 0,
    min = 1,
    max = 2,
    mul = 3,
    sum = 4,
    mean = 5,
    norm_lp_max = 6,
    norm_lp_sum = 7,
    norm_lp_power_p_max = 8,
    norm_lp_power_p_sum = 9,

    // Prefixed aliases.
    // NOTE(review): presumably kept so that spellings mirroring the
    // dnnl_reduction_* alg kinds also resolve here — confirm with users.
    reduction_min = min,
    reduction_max = max,
    reduction_mul = mul,
    reduction_sum = sum,
    reduction_mean = mean,
    reduction_norm_lp_max = norm_lp_max,
    reduction_norm_lp_sum = norm_lp_sum,
    reduction_norm_lp_power_p_max = norm_lp_power_p_max,
    reduction_norm_lp_power_p_sum = norm_lp_power_p_sum,
};
50
51alg_t str2alg(const char *str);
52const char *alg2str(alg_t alg);
53dnnl_alg_kind_t alg2alg_kind(alg_t alg);
54
55struct settings_t : public base_settings_t {
56 settings_t() = default;
57
58 // ctor to save certain fields from resetting
59 settings_t(const char *perf_template) : settings_t() {
60 this->perf_template = perf_template;
61 }
62
63 prb_vdims_t prb_vdims;
64
65 std::vector<dnnl_data_type_t> sdt {dnnl_f32};
66 std::vector<dnnl_data_type_t> ddt {dnnl_f32};
67 std::vector<std::string> stag {tag::abx};
68 std::vector<std::string> dtag {tag::any};
69 std::vector<alg_t> alg {alg_t::sum};
70 std::vector<float> p {1.0f}, eps {0.0f};
71
72 const char *perf_template_csv() const {
73 static const std::string args = "%sdt%,%ddt%,%stag%,%dtag%,%alg%";
74 return perf_template_csv_base(args);
75 }
76
77 void reset() { *this = settings_t(perf_template); }
78};
79
80struct prb_t : public prb_vdims_t {
81 prb_t(const prb_vdims_t &prb_vdims, dnnl_data_type_t sdt,
82 dnnl_data_type_t ddt, const std::string &stag,
83 const std::string &dtag, alg_t alg, float p, float eps,
84 const attr_t &attr, const thr_ctx_t &ctx_init,
85 const thr_ctx_t &ctx_exe)
86 : prb_vdims_t(prb_vdims)
87 , sdt(sdt)
88 , ddt(ddt)
89 , stag(stag)
90 , dtag(dtag)
91 , alg(alg)
92 , p(p)
93 , eps(eps)
94 , attr(attr)
95 , ctx_init(ctx_init)
96 , ctx_exe(ctx_exe) {}
97
98 dir_t dir = FLAG_FWD; // Lack of prop_kind, always considered as forward.
99 dnnl_data_type_t sdt, ddt;
100 std::string stag, dtag;
101 alg_t alg;
102 float p, eps;
103 attr_t attr;
104 thr_ctx_t ctx_init, ctx_exe;
105};
106
107std::ostream &operator<<(std::ostream &s, const prb_t &prb);
108
109struct perf_report_t : public base_perf_report_t {
110 perf_report_t(const prb_t *prb, const char *perf_template)
111 : base_perf_report_t(perf_template)
112 , prb_(prb)
113 , sdt_({prb_->sdt})
114 , stag_({normalize_tag(prb_->stag, prb_->ndims)})
115 , dtag_(normalize_tag(prb_->dtag, prb_->ndims)) {}
116
117 void dump_alg(std::ostream &s) const override { s << alg2str(prb_->alg); }
118
119 void dump_desc(std::ostream &s) const override {
120 s << static_cast<const prb_vdims_t &>(*prb_);
121 }
122
123 void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }
124
125 const attr_t *attr() const override { return &prb_->attr; }
126 const thr_ctx_t *ctx_init() const override { return &prb_->ctx_init; }
127 const thr_ctx_t *ctx_exe() const override { return &prb_->ctx_exe; }
128 const std::string *name() const override { return &prb_->name; }
129 const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; }
130 const dnnl_data_type_t *ddt() const override { return &prb_->ddt; }
131 const std::vector<std::string> *stag() const override { return &stag_; }
132 const std::string *dtag() const override { return &dtag_; }
133
134private:
135 const prb_t *prb_;
136 std::vector<dnnl_data_type_t> sdt_;
137 std::vector<std::string> stag_;
138 std::string dtag_;
139};
140
141void skip_unimplemented_prb(const prb_t *prb, res_t *res);
142void skip_invalid_prb(const prb_t *prb, res_t *res);
143void compute_ref(const prb_t *prb, const args_t &args,
144 dnnl_primitive_t prim_ref = nullptr);
145
146int doit(const prb_t *prb, res_t *res);
147int bench(int argc, char **argv);
148
149} // namespace reduction
150
151#endif
152