1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef PRELU_HPP
18#define PRELU_HPP
19
20#include "oneapi/dnnl/dnnl.h"
21
22#include "common.hpp"
23#include "dnn_types.hpp"
24#include "dnnl_common.hpp"
25#include "utils/perf_report.hpp"
26#include "utils/settings.hpp"
27
28namespace prelu {
29
30struct settings_t : public base_settings_t {
31 settings_t() = default;
32
33 // ctor to save certain fields from resetting
34 settings_t(const char *perf_template) : settings_t() {
35 this->perf_template = perf_template;
36 }
37
38 prb_vdims_t prb_vdims;
39
40 std::vector<dir_t> dir {FWD_D};
41 std::vector<std::vector<dnnl_data_type_t>> sdt {{dnnl_f32, dnnl_f32}};
42 std::vector<std::vector<std::string>> stag {{tag::abx, tag::any}};
43
44 const char *perf_template_csv() const {
45 static const std::string args = "%dir%,%sdt%,%stag%";
46 return perf_template_csv_base(args);
47 }
48
49 void reset() { *this = settings_t(perf_template); }
50};
51
52struct prb_t : public prb_vdims_t {
53 prb_t(const prb_vdims_t &prb_vdims, dir_t dir,
54 const std::vector<dnnl_data_type_t> &sdt,
55 const std::vector<std::string> &stag, const attr_t &attr,
56 const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe)
57 : prb_vdims_t(prb_vdims)
58 , dir(dir)
59 , sdt(sdt)
60 , stag(stag)
61 , attr(attr)
62 , ctx_init(ctx_init)
63 , ctx_exe(ctx_exe) {}
64 ~prb_t() {}
65
66 dir_t dir;
67 std::vector<dnnl_data_type_t> sdt;
68 std::vector<std::string> stag;
69 attr_t attr;
70 thr_ctx_t ctx_init, ctx_exe;
71};
72
73std::ostream &operator<<(std::ostream &s, const prb_t &prb);
74
75struct perf_report_t : public base_perf_report_t {
76 perf_report_t(const prb_t *prb, const char *perf_template)
77 : base_perf_report_t(perf_template), prb_(prb), stag_({}) {
78 for (size_t d = 0; d < prb_->stag.size(); d++)
79 stag_.push_back(normalize_tag(prb_->stag[d], prb_->ndims));
80 }
81
82 void dump_desc(std::ostream &s) const override {
83 s << static_cast<const prb_vdims_t &>(*prb_);
84 }
85
86 void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }
87
88 const attr_t *attr() const override { return &prb_->attr; }
89 const thr_ctx_t *ctx_init() const override { return &prb_->ctx_init; }
90 const thr_ctx_t *ctx_exe() const override { return &prb_->ctx_exe; }
91 const std::string *name() const override { return &prb_->name; }
92 const dir_t *dir() const override { return &prb_->dir; }
93 const std::vector<dnnl_data_type_t> *sdt() const override {
94 return &prb_->sdt;
95 }
96 const std::vector<std::string> *stag() const override { return &stag_; }
97
98private:
99 const prb_t *prb_;
100 std::vector<std::string> stag_;
101};
102
103int setup_prelu_po(const_dnnl_primitive_desc_t pd, std::vector<int> &args,
104 std::vector<dnn_mem_t> &ref_mem, std::vector<dnn_mem_t> &prim_mem);
105void skip_unimplemented_prb(const prb_t *prb, res_t *res);
106void skip_invalid_prb(const prb_t *prb, res_t *res);
107void compute_ref(const prb_t *prb, const args_t &args,
108 dnnl_primitive_t prim_ref = nullptr);
109
110int doit(const prb_t *prb, res_t *res);
111int bench(int argc, char **argv);
112
113} // namespace prelu
114
115#endif
116