1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef BINARY_HPP
18#define BINARY_HPP
19
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "oneapi/dnnl/dnnl.h"

#include "common.hpp"
#include "dnn_types.hpp"
#include "dnnl_common.hpp"
#include "utils/perf_report.hpp"
#include "utils/settings.hpp"
29
30namespace binary {
31
32using alg_t = attr_t::post_ops_t::kind_t;
33
34struct settings_t : public base_settings_t {
35 settings_t() = default;
36
37 // ctor to save certain fields from resetting
38 settings_t(const char *perf_template) : settings_t() {
39 this->perf_template = perf_template;
40 }
41
42 prb_vdims_t prb_vdims;
43
44 std::vector<std::vector<dnnl_data_type_t>> sdt {{dnnl_f32, dnnl_f32}};
45 std::vector<dnnl_data_type_t> ddt {dnnl_f32};
46 std::vector<std::vector<std::string>> stag {{tag::abx, tag::any}};
47 std::vector<std::string> dtag {tag::any};
48 std::vector<alg_t> alg {alg_t::ADD};
49
50 const char *perf_template_csv() const {
51 static const std::string args = "%sdt%,%ddt%,%stag%,%dtag%,%alg%";
52 return perf_template_csv_base(args);
53 }
54
55 void reset() { *this = settings_t(perf_template); }
56};
57
58struct prb_t : public prb_vdims_t {
59 prb_t(const prb_vdims_t &prb_vdims,
60 const std::vector<dnnl_data_type_t> &sdt, dnnl_data_type_t ddt,
61 const std::vector<std::string> &stag, std::string dtag, alg_t alg,
62 bool inplace, const attr_t &attr, const thr_ctx_t &ctx_init,
63 const thr_ctx_t &ctx_exe)
64 : prb_vdims_t(prb_vdims)
65 , sdt(sdt)
66 , ddt(ddt)
67 , stag(stag)
68 , dtag(dtag)
69 , alg(alg)
70 , inplace(inplace)
71 , attr(attr)
72 , ctx_init(ctx_init)
73 , ctx_exe(ctx_exe) {}
74 ~prb_t() {}
75
76 dir_t dir = FLAG_FWD; // Lack of prop_kind, always considered as forward.
77 std::vector<dnnl_data_type_t> sdt;
78 dnnl_data_type_t ddt;
79 std::vector<std::string> stag;
80 std::string dtag;
81 alg_t alg;
82 bool inplace;
83 attr_t attr;
84 thr_ctx_t ctx_init, ctx_exe;
85};
86std::ostream &operator<<(std::ostream &s, const prb_t &prb);
87
88struct perf_report_t : public base_perf_report_t {
89 perf_report_t(const prb_t *prb, const char *perf_template)
90 : base_perf_report_t(perf_template)
91 , p_(prb)
92 , stag_({})
93 , dtag_(normalize_tag(p_->dtag, p_->ndims)) {
94 for (size_t d = 0; d < p_->stag.size(); d++)
95 stag_.push_back(normalize_tag(p_->stag[d], p_->ndims));
96 }
97
98 void dump_alg(std::ostream &s) const override { s << p_->alg; }
99
100 void dump_desc(std::ostream &s) const override {
101 s << static_cast<const prb_vdims_t &>(*p_);
102 }
103
104 void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }
105
106 const std::vector<dnnl_data_type_t> *sdt() const override {
107 return &p_->sdt;
108 }
109 const attr_t *attr() const override { return &p_->attr; }
110 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
111 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
112 const std::string *name() const override { return &p_->name; }
113 const dnnl_data_type_t *ddt() const override { return &p_->ddt; }
114 const std::vector<std::string> *stag() const override { return &stag_; }
115 const std::string *dtag() const override { return &dtag_; }
116
117private:
118 const prb_t *p_;
119 std::vector<std::string> stag_;
120 std::string dtag_;
121};
122
123int setup_binary_po(const_dnnl_primitive_desc_t pd, std::vector<int> &args,
124 std::vector<dnn_mem_t> &mem_dt, std::vector<dnn_mem_t> &mem_fp,
125 bool only_positive_values = false, bool only_integer_values = false);
126
127void skip_unimplemented_prb(const prb_t *prb, res_t *res);
128void skip_invalid_prb(const prb_t *prb, res_t *res);
129void compute_ref(const prb_t *prb, const args_t &args,
130 dnnl_primitive_t prim_ref = nullptr);
131
132int doit(const prb_t *prb, res_t *res);
133int bench(int argc, char **argv);
134
135} // namespace binary
136
137#endif
138