1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef SUM_HPP
18#define SUM_HPP
19
20#include <iostream>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "common.hpp"
25#include "dnn_types.hpp"
26#include "dnnl_common.hpp"
27#include "utils/perf_report.hpp"
28#include "utils/settings.hpp"
29
30namespace sum {
31
32struct settings_t : public base_settings_t {
33 settings_t() = default;
34
35 // ctor to save certain fields from resetting
36 settings_t(const char *perf_template) : settings_t() {
37 this->perf_template = perf_template;
38 }
39
40 prb_dims_t prb_dims;
41
42 std::vector<std::vector<dnnl_data_type_t>> sdt {{dnnl_f32, dnnl_f32}};
43 std::vector<dnnl_data_type_t> ddt {dnnl_f32};
44 std::vector<std::vector<std::string>> stag {{tag::abx}};
45 std::vector<std::string> dtag {tag::undef};
46 std::vector<std::vector<float>> input_scales {{1}};
47
48 const char *perf_template_csv() const {
49 static const std::string args = "%sdt%,%ddt%,%stag%,%dtag%";
50 return perf_template_csv_base(args);
51 }
52
53 void reset() { *this = settings_t(perf_template); }
54};
55
56struct prb_t : public prb_dims_t {
57 prb_t(const prb_dims_t &prb_dims, const std::vector<dnnl_data_type_t> &sdt,
58 dnnl_data_type_t ddt, const std::vector<std::string> &stag,
59 const std::string &dtag, const std::vector<float> &input_scales,
60 bool inplace, const attr_t &attr, const thr_ctx_t &ctx_init,
61 const thr_ctx_t &ctx_exe)
62 : prb_dims_t(prb_dims)
63 , sdt(sdt)
64 , ddt(ddt)
65 , stag(stag)
66 , dtag(dtag)
67 , input_scales(input_scales)
68 , inplace(inplace)
69 , attr(attr)
70 , ctx_init(ctx_init)
71 , ctx_exe(ctx_exe) {
72 // Broadcast tag if needed
73 if (stag.size() == 1) {
74 const auto val = stag[0]; // Need a copy here.
75 this->stag.assign(n_inputs(), val);
76 }
77
78 // Broadcast input_scale if needed
79 if (input_scales.size() == 1) {
80 const auto val = input_scales[0]; // Need a copy here.
81 this->input_scales.assign(n_inputs(), val);
82 }
83 }
84 ~prb_t() {}
85
86 dir_t dir = FLAG_FWD; // Lack of prop_kind, always considered as forward.
87 std::vector<dnnl_data_type_t> sdt;
88 dnnl_data_type_t ddt;
89 std::vector<std::string> stag;
90 std::string dtag;
91 std::vector<float> input_scales;
92 bool inplace;
93 attr_t attr;
94 thr_ctx_t ctx_init, ctx_exe;
95
96 int n_inputs() const { return (int)sdt.size(); }
97};
// Stream insertion for a `prb_t`; returns a stream reference. Only the
// declaration lives in this header - the definition is not visible here.
std::ostream &operator<<(std::ostream &s, const prb_t &prb);
99
100struct perf_report_t : public base_perf_report_t {
101 perf_report_t(const prb_t *prb, const char *perf_template)
102 : base_perf_report_t(perf_template)
103 , p_(prb)
104 , stag_({})
105 , dtag_(normalize_tag(p_->dtag, p_->ndims)) {
106 for (size_t d = 0; d < p_->stag.size(); d++)
107 stag_.push_back(normalize_tag(p_->stag[d], p_->ndims));
108 }
109
110 void dump_desc(std::ostream &s) const override {
111 s << static_cast<const prb_dims_t &>(*p_);
112 }
113
114 void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }
115
116 const attr_t *attr() const override { return &p_->attr; }
117 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
118 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
119 const std::string *name() const override { return &p_->name; }
120 const std::vector<dnnl_data_type_t> *sdt() const override {
121 return &p_->sdt;
122 }
123 const dnnl_data_type_t *ddt() const override { return &p_->ddt; }
124 const std::vector<std::string> *stag() const override { return &stag_; }
125 const std::string *dtag() const override { return &dtag_; }
126
127private:
128 const prb_t *p_;
129 std::vector<std::string> stag_;
130 std::string dtag_;
131};
132
// Free functions of the `sum` namespace. Only declarations appear in this
// header.
// NOTE(review): the per-function remarks below are inferred from names and
// signatures alone - confirm against the definitions.

// Presumably records in `res` that `prb` has no implementation and is to be
// skipped.
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
// Presumably records in `res` that `prb` is not a valid problem and is to be
// skipped.
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Presumably computes the reference result for `prb` using `args`.
// `prim_ref` is optional and defaults to nullptr.
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);

// Presumably runs one problem and reports the outcome via `res`; the meaning
// of the int return value is not visible here.
int doit(const prb_t *prb, res_t *res);
// Presumably the command-line entry point; takes the usual argc/argv pair.
int bench(int argc, char **argv);
140
141} // namespace sum
142
143#endif
144