1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef PRELU_HPP |
18 | #define PRELU_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #include "common.hpp" |
23 | #include "dnn_types.hpp" |
24 | #include "dnnl_common.hpp" |
25 | #include "utils/perf_report.hpp" |
26 | #include "utils/settings.hpp" |
27 | |
28 | namespace prelu { |
29 | |
30 | struct settings_t : public base_settings_t { |
31 | settings_t() = default; |
32 | |
33 | // ctor to save certain fields from resetting |
34 | settings_t(const char *perf_template) : settings_t() { |
35 | this->perf_template = perf_template; |
36 | } |
37 | |
38 | prb_vdims_t prb_vdims; |
39 | |
40 | std::vector<dir_t> dir {FWD_D}; |
41 | std::vector<std::vector<dnnl_data_type_t>> sdt {{dnnl_f32, dnnl_f32}}; |
42 | std::vector<std::vector<std::string>> stag {{tag::abx, tag::any}}; |
43 | |
44 | const char *perf_template_csv() const { |
45 | static const std::string args = "%dir%,%sdt%,%stag%" ; |
46 | return perf_template_csv_base(args); |
47 | } |
48 | |
49 | void reset() { *this = settings_t(perf_template); } |
50 | }; |
51 | |
52 | struct prb_t : public prb_vdims_t { |
53 | prb_t(const prb_vdims_t &prb_vdims, dir_t dir, |
54 | const std::vector<dnnl_data_type_t> &sdt, |
55 | const std::vector<std::string> &stag, const attr_t &attr, |
56 | const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe) |
57 | : prb_vdims_t(prb_vdims) |
58 | , dir(dir) |
59 | , sdt(sdt) |
60 | , stag(stag) |
61 | , attr(attr) |
62 | , ctx_init(ctx_init) |
63 | , ctx_exe(ctx_exe) {} |
64 | ~prb_t() {} |
65 | |
66 | dir_t dir; |
67 | std::vector<dnnl_data_type_t> sdt; |
68 | std::vector<std::string> stag; |
69 | attr_t attr; |
70 | thr_ctx_t ctx_init, ctx_exe; |
71 | }; |
72 | |
73 | std::ostream &operator<<(std::ostream &s, const prb_t &prb); |
74 | |
75 | struct perf_report_t : public base_perf_report_t { |
76 | perf_report_t(const prb_t *prb, const char *perf_template) |
77 | : base_perf_report_t(perf_template), prb_(prb), stag_({}) { |
78 | for (size_t d = 0; d < prb_->stag.size(); d++) |
79 | stag_.push_back(normalize_tag(prb_->stag[d], prb_->ndims)); |
80 | } |
81 | |
82 | void dump_desc(std::ostream &s) const override { |
83 | s << static_cast<const prb_vdims_t &>(*prb_); |
84 | } |
85 | |
86 | void dump_desc_csv(std::ostream &s) const override { dump_desc(s); } |
87 | |
88 | const attr_t *attr() const override { return &prb_->attr; } |
89 | const thr_ctx_t *ctx_init() const override { return &prb_->ctx_init; } |
90 | const thr_ctx_t *ctx_exe() const override { return &prb_->ctx_exe; } |
91 | const std::string *name() const override { return &prb_->name; } |
92 | const dir_t *dir() const override { return &prb_->dir; } |
93 | const std::vector<dnnl_data_type_t> *sdt() const override { |
94 | return &prb_->sdt; |
95 | } |
96 | const std::vector<std::string> *stag() const override { return &stag_; } |
97 | |
98 | private: |
99 | const prb_t *prb_; |
100 | std::vector<std::string> stag_; |
101 | }; |
102 | |
103 | int setup_prelu_po(const_dnnl_primitive_desc_t pd, std::vector<int> &args, |
104 | std::vector<dnn_mem_t> &ref_mem, std::vector<dnn_mem_t> &prim_mem); |
105 | void skip_unimplemented_prb(const prb_t *prb, res_t *res); |
106 | void skip_invalid_prb(const prb_t *prb, res_t *res); |
107 | void compute_ref(const prb_t *prb, const args_t &args, |
108 | dnnl_primitive_t prim_ref = nullptr); |
109 | |
110 | int doit(const prb_t *prb, res_t *res); |
111 | int bench(int argc, char **argv); |
112 | |
113 | } // namespace prelu |
114 | |
115 | #endif |
116 | |