1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CONCAT_HPP
18#define CONCAT_HPP
19
20#include <iostream>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "common.hpp"
25#include "dnn_types.hpp"
26#include "dnnl_common.hpp"
27#include "utils/perf_report.hpp"
28#include "utils/settings.hpp"
29
30namespace concat {
31
32struct settings_t : public base_settings_t {
33 settings_t() = default;
34
35 // ctor to save certain fields from resetting
36 settings_t(const char *perf_template) : settings_t() {
37 this->perf_template = perf_template;
38 }
39
40 prb_vdims_t prb_vdims;
41
42 std::vector<dnnl_data_type_t> sdt {dnnl_f32}, ddt {dnnl_f32};
43 std::vector<std::vector<std::string>> stag {{tag::abx}};
44 std::vector<std::string> dtag {tag::undef};
45 std::vector<int> axis {1};
46
47 const char *perf_template_csv() const {
48 static const std::string args = "%sdt%,%ddt%,%stag%,%dtag%,%axis%";
49 return perf_template_csv_base(args);
50 }
51
52 void reset() { *this = settings_t(perf_template); }
53};
54
55struct prb_t : public prb_vdims_t {
56 prb_t(const prb_vdims_t &prb_vdims, dnnl_data_type_t sdt,
57 dnnl_data_type_t ddt, const std::vector<std::string> &stag,
58 const std::string &dtag, int axis, const attr_t &attr,
59 const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe)
60 : prb_vdims_t(prb_vdims)
61 , sdt(sdt)
62 , ddt(ddt)
63 , stag(stag)
64 , dtag(dtag)
65 , axis(axis)
66 , attr(attr)
67 , ctx_init(ctx_init)
68 , ctx_exe(ctx_exe) {
69 // If dst is omitted by `dtag = tag::undef`, omit `ddt` as well.
70 if (dtag == tag::undef) this->ddt = dnnl_data_type_undef;
71
72 // Broadcast tag if needed
73 if (stag.size() == 1) {
74 const auto val = stag[0]; // Need a copy here.
75 this->stag.assign(prb_vdims.n_inputs(), val);
76 }
77
78 dst_dims[axis] = axis_size();
79 }
80 ~prb_t() {}
81
82 dir_t dir = FLAG_FWD; // Lack of prop_kind, always considered as forward.
83 dnnl_data_type_t sdt, ddt;
84 std::vector<std::string> stag;
85 std::string dtag;
86 int axis;
87 attr_t attr;
88 thr_ctx_t ctx_init, ctx_exe;
89
90 int64_t axis_size() const {
91 int64_t as = 0;
92 for (int i = 0; i < n_inputs(); ++i)
93 as += vdims[i].at(axis);
94 return as;
95 }
96};
// Stream output for a concat problem; defined outside this header.
std::ostream &operator<<(std::ostream &s, const prb_t &prb);
98
99struct perf_report_t : public base_perf_report_t {
100 perf_report_t(const prb_t *prb, const char *perf_template)
101 : base_perf_report_t(perf_template)
102 , p_(prb)
103 , sdt_({p_->sdt})
104 , stag_({})
105 , dtag_(normalize_tag(p_->dtag, p_->ndims)) {
106 for (size_t d = 0; d < p_->stag.size(); d++)
107 stag_.push_back(normalize_tag(p_->stag[d], p_->ndims));
108 }
109
110 void dump_desc(std::ostream &s) const override {
111 s << static_cast<const prb_vdims_t &>(*p_);
112 }
113
114 void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }
115
116 const attr_t *attr() const override { return &p_->attr; }
117 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
118 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
119 const std::string *name() const override { return &p_->name; }
120 const int *axis() const override { return &p_->axis; }
121 const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; }
122 const dnnl_data_type_t *ddt() const override { return &p_->ddt; }
123 const std::vector<std::string> *stag() const override { return &stag_; }
124 const std::string *dtag() const override { return &dtag_; }
125
126private:
127 const prb_t *p_;
128 std::vector<dnnl_data_type_t> sdt_;
129 std::vector<std::string> stag_;
130 std::string dtag_;
131};
132
// Concat driver entry points; only declarations live in this header.

// NOTE(review): judging by the names, these presumably mark `res` as skipped
// when the problem is unimplemented / invalid — confirm against definitions.
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Reference computation for `prb` over `args`; `prim_ref` is optional and
// defaults to nullptr.
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);

// NOTE(review): presumably runs a single problem, recording the outcome in
// `res` and returning an int status — confirm.
int doit(const prb_t *prb, res_t *res);
// Driver entry taking argc/argv-style arguments.
int bench(int argc, char **argv);
140
141} // namespace concat
142
143#endif
144