1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef ELTWISE_HPP |
18 | #define ELTWISE_HPP |
19 | |
20 | #include <iostream> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "common.hpp" |
25 | #include "dnn_types.hpp" |
26 | #include "dnnl_common.hpp" |
27 | #include "utils/perf_report.hpp" |
28 | #include "utils/settings.hpp" |
29 | |
30 | namespace eltwise { |
31 | |
// Eltwise algorithm kinds are not a separate enumeration: they reuse the
// post-ops kind enumeration of attr_t, so algorithms are named with the same
// values (e.g. alg_t::RELU, alg_t::TANH_DST) as post-op kinds.
using alg_t = attr_t::post_ops_t::kind_t;
33 | |
34 | struct settings_t : public base_settings_t { |
35 | settings_t() = default; |
36 | |
37 | // ctor to save certain fields from resetting |
38 | settings_t(const char *perf_template) : settings_t() { |
39 | this->perf_template = perf_template; |
40 | } |
41 | |
42 | prb_dims_t prb_dims; |
43 | |
44 | std::vector<dir_t> dir {FWD_D}; |
45 | std::vector<dnnl_data_type_t> dt {dnnl_f32}; |
46 | std::vector<std::string> tag {tag::abx}; |
47 | std::vector<alg_t> alg {alg_t::RELU}; |
48 | std::vector<float> alpha {0.f}, beta {0.f}; |
49 | |
50 | const char *perf_template_csv() const { |
51 | static const std::string args = "%dir%,%dt%,%tag%,%alg%" ; |
52 | return perf_template_csv_base(args); |
53 | } |
54 | |
55 | void reset() { *this = settings_t(perf_template); } |
56 | }; |
57 | |
58 | struct prb_t : public prb_dims_t { |
59 | prb_t(const prb_dims_t &prb_dims, dir_t dir, dnnl_data_type_t dt, |
60 | const std::string &tag, alg_t alg, float alpha, float beta, |
61 | bool inplace, const attr_t &attr, const thr_ctx_t &ctx_init, |
62 | const thr_ctx_t &ctx_exe, int64_t mb = 0) |
63 | : prb_dims_t(prb_dims) |
64 | , dir(dir) |
65 | , dt(dt) |
66 | , tag(tag) |
67 | , alg(alg) |
68 | , alpha(alpha) |
69 | , beta(beta) |
70 | , inplace(inplace) |
71 | , attr(attr) |
72 | , ctx_init(ctx_init) |
73 | , ctx_exe(ctx_exe) |
74 | , user_mb(mb) { |
75 | if (mb) dims[0] = mb; |
76 | } |
77 | ~prb_t() {} |
78 | |
79 | dir_t dir; |
80 | dnnl_data_type_t dt; |
81 | std::string tag; |
82 | alg_t alg; |
83 | float alpha, beta; |
84 | bool inplace; |
85 | attr_t attr; |
86 | const thr_ctx_t ctx_init, ctx_exe; |
87 | int64_t user_mb; |
88 | |
89 | bool use_dst() const { |
90 | return alg == alg_t::RELU_DST || alg == alg_t::TANH_DST |
91 | || alg == alg_t::ELU_DST || alg == alg_t::SQRT_DST |
92 | || alg == alg_t::LOGISTIC_DST || alg == alg_t::EXP_DST |
93 | || alg == alg_t::CLIP_V2_DST; |
94 | } |
95 | }; |
96 | std::ostream &operator<<(std::ostream &s, const prb_t &prb); |
97 | |
98 | struct perf_report_t : public base_perf_report_t { |
99 | perf_report_t(const prb_t *prb, const char *perf_template) |
100 | : base_perf_report_t(perf_template) |
101 | , p_(prb) |
102 | , tag_(normalize_tag(p_->tag, p_->ndims)) {} |
103 | |
104 | void dump_alg(std::ostream &s) const override { s << p_->alg; } |
105 | |
106 | void dump_desc(std::ostream &s) const override { |
107 | s << static_cast<const prb_dims_t &>(*p_); |
108 | } |
109 | |
110 | void dump_desc_csv(std::ostream &s) const override { dump_desc(s); } |
111 | |
112 | const attr_t *attr() const override { return &p_->attr; } |
113 | const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; } |
114 | const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; } |
115 | const std::string *name() const override { return &p_->name; } |
116 | const dir_t *dir() const override { return &p_->dir; } |
117 | const dnnl_data_type_t *dt() const override { return &p_->dt; } |
118 | const int64_t *user_mb() const override { return &p_->user_mb; } |
119 | const std::string *tag() const override { return &tag_; } |
120 | |
121 | private: |
122 | const prb_t *p_; |
123 | std::string tag_; |
124 | }; |
125 | |
126 | float get_eltwise_threshold(dnnl_data_type_t dt, alg_t alg, bool is_fwd = true); |
127 | void skip_unimplemented_prb(const prb_t *prb, res_t *res); |
128 | void skip_invalid_prb(const prb_t *prb, res_t *res); |
129 | void compute_ref(const prb_t *prb, const args_t &args, |
130 | dnnl_primitive_t prim_ref = nullptr); |
131 | |
132 | int doit(const prb_t *prb, res_t *res); |
133 | int bench(int argc, char **argv); |
134 | |
135 | } // namespace eltwise |
136 | |
137 | #endif |
138 | |